Basic Slack Bot Support (#128)

This commit is contained in:
Chris Weaver 2023-07-03 14:26:33 -07:00 committed by GitHub
parent 381b3719c9
commit 2f54795631
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 378 additions and 119 deletions

View File

@ -1,6 +1,5 @@
import json
import os
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any
@ -15,15 +14,14 @@ from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.utils.logging import setup_logger
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
logger = setup_logger()
SLACK_LIMIT = 900
ChannelType = dict[str, Any]
MessageType = dict[str, Any]
@ -31,61 +29,10 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., list[dict[str, Any]]]:
"""Wraps calls to slack API so that they automatically handle pagination"""
def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = []
# cursor is None on the first request; afterwards it carries the
# next_cursor value Slack returned with the previous page
cursor: str | None = None
has_more = True
while has_more:
# iterating the response yields result objects; each one reports
# whether more data remains and the cursor for the next request
for result in call(cursor=cursor, limit=SLACK_LIMIT, **kwargs):
has_more = result.get("has_more", False)
cursor = result.get("response_metadata", {}).get("next_cursor", "")
results.append(cast(dict[str, Any], result))
return results
return paginated_call
def _make_slack_api_rate_limited(
call: Callable[..., SlackResponse], max_retries: int = 3
) -> Callable[..., SlackResponse]:
"""Wraps calls to slack API so that they automatically handle rate limiting"""
def rate_limited_call(**kwargs: Any) -> SlackResponse:
# retry up to max_retries times; only rate-limit errors are retried
for _ in range(max_retries):
try:
# Make the API call
response = call(**kwargs)
# Check for errors in the response
if response.get("ok"):
return response
else:
# surface a non-ok response as SlackApiError so the handler
# below treats it uniformly with raised API errors
raise SlackApiError("", response)
except SlackApiError as e:
if e.response["error"] == "ratelimited":
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
retry_after = int(e.response.headers.get("Retry-After", 1))
time.sleep(retry_after)
else:
# Raise the error for non-transient errors
raise
# If the code reaches this point, all retries have been exhausted
raise Exception(f"Max retries ({max_retries}) exceeded")
return rate_limited_call
def _make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> list[dict[str, Any]]:
return _make_slack_api_call_paginated(_make_slack_api_rate_limited(call))(**kwargs)
return make_slack_api_call_paginated(make_slack_api_rate_limited(call))(**kwargs)
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:

View File

@ -1,6 +1,14 @@
import time
from collections.abc import Callable
from typing import Any
from typing import cast
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900
def get_message_link(
event: dict[str, Any], workspace: str, channel_id: str | None = None
@ -13,3 +21,54 @@ def get_message_link(
return (
f"https://{workspace}.slack.com/archives/{channel_id}/p{message_ts_without_dot}"
)
def make_slack_api_call_paginated(
    call: Callable[..., SlackResponse],
) -> Callable[..., list[dict[str, Any]]]:
    """Wrap a Slack API call so that all pages of results are fetched.

    Returns a callable that keeps requesting with the cursor Slack hands
    back until a result reports no more data, collecting every result.
    """

    def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
        collected: list[dict[str, Any]] = []
        # None on the first request; replaced by Slack's next_cursor afterwards
        next_cursor: str | None = None
        more_to_fetch = True
        while more_to_fetch:
            for result in call(cursor=next_cursor, limit=_SLACK_LIMIT, **kwargs):
                more_to_fetch = result.get("has_more", False)
                metadata = result.get("response_metadata", {})
                next_cursor = metadata.get("next_cursor", "")
                collected.append(cast(dict[str, Any], result))
        return collected

    return paginated_call
def make_slack_api_rate_limited(
    call: Callable[..., SlackResponse], max_retries: int = 3
) -> Callable[..., SlackResponse]:
    """Wrap a Slack API call so rate-limit responses are retried.

    The returned callable retries up to `max_retries` times, sleeping for
    the duration Slack requests via the `Retry-After` header whenever the
    API reports `ratelimited`. Any other API error is raised immediately.
    """

    def rate_limited_call(**kwargs: Any) -> SlackResponse:
        for _attempt in range(max_retries):
            try:
                response = call(**kwargs)
                if not response.get("ok"):
                    # surface non-ok responses as SlackApiError so the
                    # rate-limit handling below applies to them uniformly
                    raise SlackApiError("", response)
                return response
            except SlackApiError as e:
                if e.response["error"] != "ratelimited":
                    # non-transient errors are not retried
                    raise
                # rate limited: wait as long as Slack asks, then retry
                delay = int(e.response.headers.get("Retry-After", 1))
                time.sleep(delay)
        # every attempt was rate limited
        raise Exception(f"Max retries ({max_retries}) exceeded")

    return rate_limited_call

View File

@ -2,8 +2,8 @@ from typing import Any
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.question_answer import OpenAIChatCompletionQA
from danswer.direct_qa.question_answer import OpenAICompletionQA
from danswer.direct_qa.llm import OpenAIChatCompletionQA
from danswer.direct_qa.llm import OpenAICompletionQA
def get_default_backend_qa_model(

View File

@ -0,0 +1,82 @@
import time
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.search.danswer_helper import query_intent
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.utils.logging import setup_logger
logger = setup_logger()
def answer_question(question: QuestionRequest, user: User | None) -> QAResponse:
    """Run the full QA pipeline for a question and package the result.

    Retrieves candidate chunks (keyword or semantic search, chosen by the
    request or by predicted intent), feeds the top chunks to the backend QA
    model, and returns a QAResponse. LLM failures are captured in
    `error_msg` rather than raised; an offset past the retrieved chunks
    raises ValueError.
    """
    started = time.time()

    query = question.query
    logger.info(f"Received QA query: {query}")

    predicted_search, predicted_flow = query_intent(query)

    # fall back to the predicted search type when the caller didn't choose
    use_keyword = question.use_keyword
    if use_keyword is None:
        use_keyword = predicted_search == SearchType.KEYWORD

    user_id = user.id if user is not None else None
    if use_keyword:
        ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
            query, user_id, question.filters, TypesenseIndex(question.collection)
        )
        unranked_chunks: list[InferenceChunk] | None = []
    else:
        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
            query, user_id, question.filters, QdrantIndex(question.collection)
        )

    if not ranked_chunks:
        # nothing retrieved -> no answer possible
        return QAResponse(
            answer=None,
            quotes=None,
            top_ranked_docs=None,
            lower_ranked_docs=None,
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
        )

    qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)

    # offset lets callers page past chunks already tried in earlier attempts
    offset_count = question.offset if question.offset is not None else 0
    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
    if chunk_offset >= len(ranked_chunks):
        raise ValueError("Chunks offset too large, should not retry this many times")

    error_msg = None
    try:
        answer, quotes = qa_model.answer_question(
            query,
            ranked_chunks[chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS],
        )
    except Exception as e:
        # exception is logged in the answer_question method, no need to re-log
        answer, quotes = None, None
        error_msg = f"Error occurred in call to LLM - {e}"

    logger.info(f"Total QA took {time.time() - started} seconds")
    return QAResponse(
        answer=answer,
        quotes=quotes,
        top_ranked_docs=chunks_to_search_docs(ranked_chunks),
        lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
        predicted_flow=predicted_flow,
        predicted_search=predicted_search,
        error_msg=error_msg,
    )

View File

@ -1,5 +1,5 @@
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.question_answer import OpenAIQAModel
from danswer.direct_qa.llm import OpenAIQAModel
from openai.error import AuthenticationError
from openai.error import Timeout

View File

@ -0,0 +1,206 @@
import os
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.direct_qa.answer_question import answer_question
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc
from danswer.utils.logging import setup_logger
from slack_sdk import WebClient
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
logger = setup_logger()
# number of attempts made both when answering a question and when posting
# the reply before giving up on a message
_NUM_RETRIES = 3
# max number of "other potentially relevant documents" listed in a reply
_NUM_DOCS_TO_DISPLAY = 5
def _get_socket_client() -> SocketModeClient:
    """Build a Slack Socket Mode client from the DANSWER_BOT_* env vars.

    Raises RuntimeError if either token is missing. For more info on how
    to set this up, checkout the docs:
    https://docs.danswer.dev/slack_bot_setup
    """
    app_token = os.environ.get("DANSWER_BOT_SLACK_APP_TOKEN")
    if not app_token:
        raise RuntimeError("DANSWER_BOT_SLACK_APP_TOKEN is not set")

    bot_token = os.environ.get("DANSWER_BOT_SLACK_BOT_TOKEN")
    if not bot_token:
        raise RuntimeError("DANSWER_BOT_SLACK_BOT_TOKEN is not set")

    # the app-level token is used only for establishing the connection;
    # the bot token backs the web client that actually posts messages
    return SocketModeClient(
        app_token=app_token,
        web_client=WebClient(token=bot_token),
    )
def _process_quotes(
    quotes: dict[str, dict[str, str | int | None]] | None
) -> tuple[str | None, list[str]]:
    """Format quote metadata as Slack link bullet lines.

    Returns the joined "- <link|name>" lines and the list of document
    names used, deduplicated by name. Returns (None, []) when there are
    no usable quotes.
    """
    if not quotes:
        return None, []

    lines: list[str] = []
    used_doc_names: list[str] = []
    for quote in quotes.values():
        link = quote.get("document_id")
        name = quote.get("semantic_identifier")
        # skip quotes missing a link/name and avoid repeating a document
        if not link or not name or name in used_doc_names:
            continue
        used_doc_names.append(str(name))
        lines.append(f"- <{link}|{name}>")

    if not lines:
        return None, []
    return "\n".join(lines), used_doc_names
def _process_documents(
    documents: list[SearchDoc] | None, already_displayed_doc_identifiers: list[str]
) -> str | None:
    """Format the top retrieved documents not already shown as quotes.

    Returns a Slack-formatted section listing up to _NUM_DOCS_TO_DISPLAY
    document links, or None when there are no documents at all.
    """
    if not documents:
        return None

    # drop documents already surfaced in the quotes section, keep the rest
    remaining = (
        doc
        for doc in documents
        if doc.semantic_identifier not in already_displayed_doc_identifiers
    )
    shown = list(itertools.islice(remaining, _NUM_DOCS_TO_DISPLAY))
    links = "\n".join(f"- <{doc.link}|{doc.semantic_identifier}>" for doc in shown)
    return "*Other potentially relevant documents:*\n" + links
def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
    """Answer a new channel message delivered over Socket Mode.

    Only `events_api` requests containing a plain `message` event are
    processed; special message subtypes (channel_join etc.), bot messages,
    and empty messages are acknowledged and ignored. The answer produced by
    `answer_question` is posted back as a thread reply, with up to
    `_NUM_RETRIES` attempts for both answering and posting.
    """
    if req.type == "events_api":
        # Acknowledge the request no matter what, or Slack will redeliver it
        response = SocketModeResponse(envelope_id=req.envelope_id)
        client.send_socket_mode_response(response)

        # hoist the event payload - it is read many times below
        event = req.payload.get("event", {})

        # Ensure that the message is a new message + of expected type
        event_type = event.get("type")
        if event_type != "message":
            logger.info(f"Ignoring non-message event of type '{event_type}'")
            # FIX: previously this fell through and processed the event anyway
            return

        message_subtype = event.get("subtype")
        if message_subtype is not None:
            # this covers things like channel_join, channel_leave, etc.
            logger.info(
                f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
            )
            return

        if event.get("bot_profile"):
            # avoid responding to bots (including our own replies)
            logger.info("Ignoring message from bot")
            return

        msg = event.get("text")
        thread_ts = event.get("ts")
        if not msg:
            logger.error("Unable to process empty message")
            return

        # TODO: message should be enqueued and processed elsewhere,
        # but doing it here for now for simplicity
        def _get_answer(question: QuestionRequest) -> QAResponse | None:
            # returns None on any failure so the caller can retry
            try:
                answer = answer_question(question=question, user=None)
                if not answer.error_msg:
                    return answer
                else:
                    raise RuntimeError(answer.error_msg)
            except Exception as e:
                logger.error(f"Unable to process message: {e}")
                return None

        answer = None
        for _ in range(_NUM_RETRIES):
            answer = _get_answer(
                QuestionRequest(
                    query=msg,
                    collection=QDRANT_DEFAULT_COLLECTION,
                    use_keyword=False,  # always semantic search for the bot
                    filters=None,
                    offset=None,
                )
            )
            if answer:
                break
        if not answer:
            logger.error(
                f"Unable to process message - did not successfully answer in {_NUM_RETRIES} attempts"
            )
            return
        if not answer.answer:
            logger.error("Unable to process message - no answer found")
            return

        # convert raw response into "nicely" formatted Slack message
        quote_str, doc_identifiers = _process_quotes(answer.quotes)
        top_documents_str = _process_documents(answer.top_ranked_docs, doc_identifiers)
        if quote_str:
            text = f"{answer.answer}\n\n*Sources:*\n{quote_str}"
        else:
            text = (
                f"{answer.answer}\n\n*Warning*: no sources were quoted for this answer, "
                "so it may be unreliable 😔"
            )
        if top_documents_str:
            # FIX: previously a literal "None" was appended to the reply
            # when there were no extra documents to show
            text += f"\n\n{top_documents_str}"

        def _respond_in_thread(
            channel: str,
            text: str,
            thread_ts: str,
        ) -> str | None:
            # returns an error message on failure, None on success
            slack_call = make_slack_api_rate_limited(client.web_client.chat_postMessage)
            response = slack_call(
                channel=channel,
                text=text,
                thread_ts=thread_ts,
            )
            if not response.get("ok"):
                return f"Unable to post message: {response}"
            return None

        successfully_answered = False
        for _ in range(_NUM_RETRIES):
            error_msg = _respond_in_thread(
                channel=event.get("channel"),
                text=text,
                thread_ts=thread_ts,
            )
            if error_msg:
                logger.error(error_msg)
            else:
                successfully_answered = True
                break
        if not successfully_answered:
            logger.error(
                f"Unable to process message - could not respond in slack in {_NUM_RETRIES} attempts"
            )
            return

        logger.info(f"Successfully processed message with ts: '{thread_ts}'")
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
# the slack bot in your workspace, and then add the bot to any channels you want to
# try to answer questions for. Running this file will set up Danswer to listen to all
# messages in those channels and attempt to answer them. As of now, it will only respond
# to messages sent directly in the channel - it will not respond to messages sent within a
# thread.
#
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
# Entry point: wire up the Socket Mode client and register the event
# handler that answers messages (see the setup notes in the comment above)
socket_client = _get_socket_client()
socket_client.socket_mode_request_listeners.append(process_slack_event)  # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info("Listening for messages from Slack...")
socket_client.connect()
# Block forever so the listener keeps running (connect() does not block)
from threading import Event
Event().wait()

View File

@ -38,7 +38,7 @@ from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.models import User
from danswer.direct_qa.key_validation import check_openai_api_key_is_valid
from danswer.direct_qa.question_answer import get_openai_api_key
from danswer.direct_qa.llm import get_openai_api_key
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey

View File

@ -105,6 +105,7 @@ class QAResponse(SearchResponse):
quotes: dict[str, dict[str, str | int | None]] | None
predicted_flow: QueryFlow
predicted_search: SearchType
error_msg: str | None = None
class UserByEmail(BaseModel):

View File

@ -9,7 +9,8 @@ from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.question_answer import get_json_line
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.llm import get_json_line
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
@ -85,62 +86,7 @@ def keyword_search(
def direct_qa(
question: QuestionRequest, user: User = Depends(current_user)
) -> QAResponse:
start_time = time.time()
query = question.query
collection = question.collection
filters = question.filters
use_keyword = question.use_keyword
offset_count = question.offset if question.offset is not None else 0
logger.info(f"Received QA query: {query}")
predicted_search, predicted_flow = query_intent(query)
if use_keyword is None:
use_keyword = predicted_search == SearchType.KEYWORD
user_id = None if user is None else user.id
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query, user_id, filters, TypesenseIndex(collection)
)
unranked_chunks: list[InferenceChunk] | None = []
else:
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, QdrantIndex(collection)
)
if not ranked_chunks:
return QAResponse(
answer=None,
quotes=None,
top_ranked_docs=None,
lower_ranked_docs=None,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
)
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
if chunk_offset >= len(ranked_chunks):
raise ValueError("Chunks offset too large, should not retry this many times")
try:
answer, quotes = qa_model.answer_question(
query,
ranked_chunks[chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS],
)
except Exception:
# exception is logged in the answer_question method, no need to re-log
answer, quotes = None, None
logger.info(f"Total QA took {time.time() - start_time} seconds")
return QAResponse(
answer=answer,
quotes=quotes,
top_ranked_docs=chunks_to_search_docs(ranked_chunks),
lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
predicted_flow=predicted_flow,
predicted_search=predicted_search,
)
return answer_question(question=question, user=user)
@router.post("/stream-direct-qa")

View File

@ -16,3 +16,15 @@ stdout_logfile=/var/log/file_deletion.log
redirect_stderr=true
stdout_logfile_maxbytes=52428800
autorestart=true
# Listens for slack messages and responds with answers
# for all channels that the DanswerBot has been added to.
# If not set up, this will just fail 5 times and then stop.
# More details on setup here: https://docs.danswer.dev/slack_bot_setup
[program:slack_bot_listener]
command=python danswer/listeners/slack_listener.py
stdout_logfile=/var/log/slack_bot_listener.log
redirect_stderr=true
stdout_logfile_maxbytes=52428800
autorestart=true
startretries=5

View File

@ -2,8 +2,8 @@ import textwrap
import unittest
from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.question_answer import match_quotes_to_docs
from danswer.direct_qa.question_answer import separate_answer_quotes
from danswer.direct_qa.llm import match_quotes_to_docs
from danswer.direct_qa.llm import separate_answer_quotes
class TestQAPostprocessing(unittest.TestCase):

View File

@ -27,6 +27,12 @@ ENABLE_OAUTH=True
GOOGLE_OAUTH_CLIENT_ID=
GOOGLE_OAUTH_CLIENT_SECRET=
# If you want to setup a slack bot to answer questions automatically in Slack
# channels it is added to, you must specify the below.
# More information in the guide here: https://docs.danswer.dev/slack_bot_setup
DANSWER_BOT_SLACK_APP_TOKEN=
DANSWER_BOT_SLACK_BOT_TOKEN=
# Used to generate values for security verification, use a random string
SECRET=