mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 11:58:34 +02:00
updated tool runner + displays
This commit is contained in:
parent
2beffdaa6e
commit
3170430673
Binary file not shown.
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 27 KiB |
@ -709,7 +709,10 @@ def stream_chat_message_objects(
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_arguments={
|
||||
k: v if not isinstance(v, bytes) else v.decode("utf-8")
|
||||
for k, v in tool_result.tool_args.items()
|
||||
},
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
|
||||
@ -814,11 +817,12 @@ def stream_chat_message_objects(
|
||||
)
|
||||
elif packet.id == GRAPHING_RESPONSE_ID:
|
||||
graph_generation = cast(GraphingResponse, packet.response)
|
||||
yield graph_generation
|
||||
|
||||
yield GraphGenerationDisplay(
|
||||
file_id=graph_generation.extra_graph_display.file_id,
|
||||
line_graph=graph_generation.extra_graph_display.line_graph,
|
||||
)
|
||||
# yield GraphGenerationDisplay(
|
||||
# file_id=graph_generation.extra_graph_display.file_id,
|
||||
# line_graph=graph_generation.extra_graph_display.line_graph,
|
||||
# )
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
|
@ -1,19 +1,22 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
|
||||
class DateTimeAndBytesEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that converts datetime objects to ISO format strings and bytes to base64."""
|
||||
|
||||
def default(self, obj: Any) -> Any:
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
elif isinstance(obj, bytes):
|
||||
return base64.b64encode(obj).decode("utf-8")
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
def get_json_line(
|
||||
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder
|
||||
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeAndBytesEncoder
|
||||
) -> str:
|
||||
"""
|
||||
Convert a dictionary to a JSON string with datetime handling, and add a newline.
|
||||
|
@ -1,4 +1,3 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@ -29,7 +28,7 @@ from danswer.prompts.chat_prompts import (
|
||||
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
|
||||
from danswer.tools.graphing.models import GraphingError
|
||||
from danswer.tools.graphing.models import GraphingResponse
|
||||
from danswer.tools.graphing.models import GraphingResult
|
||||
from danswer.tools.graphing.models import GraphType
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -276,6 +275,7 @@ class GraphingTool(Tool):
|
||||
print(kwargs)
|
||||
|
||||
file_content = kwargs["filename"]
|
||||
file_content = file_content.decode("utf-8")
|
||||
csv_file = StringIO(file_content)
|
||||
df = pd.read_csv(csv_file)
|
||||
|
||||
@ -320,13 +320,13 @@ class GraphingTool(Tool):
|
||||
ax = fig.gca() # Get the current Axes
|
||||
|
||||
plot_data = None
|
||||
plot_type = None
|
||||
plot_type: GraphType | None = None
|
||||
if self.is_line_plot(ax):
|
||||
plot_data = self.extract_line_plot_data(ax)
|
||||
plot_type = "line"
|
||||
plot_type = GraphType.LINE_GRAPH
|
||||
elif self.is_bar_plot(ax):
|
||||
plot_data = self.extract_bar_plot_data(ax)
|
||||
plot_type = "bar"
|
||||
plot_type = GraphType.BAR_CHART
|
||||
|
||||
if plot_data:
|
||||
plot_data_file = os.path.join(self.output_dir, "plot_data.json")
|
||||
@ -350,23 +350,19 @@ class GraphingTool(Tool):
|
||||
|
||||
buf = BytesIO()
|
||||
fig.savefig(buf, format="png", bbox_inches="tight") # type: ignore
|
||||
img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
||||
with open("aaa garp.png", "wb") as f:
|
||||
f.write(buf.getvalue())
|
||||
|
||||
graph_result = GraphingResult(image=img_base64, plot_data=plot_data)
|
||||
print("da plot type iza")
|
||||
print(plot_type)
|
||||
print("\n\n\n")
|
||||
print(code)
|
||||
response = GraphingResponse(
|
||||
graph_result=graph_result,
|
||||
extra_graph_display={
|
||||
"file_id": file_id,
|
||||
"line_graph": plot_type == "line",
|
||||
},
|
||||
yield ToolResponse(
|
||||
id=GRAPHING_RESPONSE_ID,
|
||||
response=GraphingResponse(
|
||||
file_id=str(file_id),
|
||||
graph_type=plot_type.value
|
||||
if plot_type
|
||||
else None, # Use .value to get the string
|
||||
plot_data=plot_data, # Pass the dictionary directly, not as a JSON string
|
||||
),
|
||||
)
|
||||
yield ToolResponse(id=GRAPHING_RESPONSE_ID, response=response)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error generating graph: {str(e)}\n{traceback.format_exc()}"
|
||||
|
@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -8,21 +9,15 @@ class GraphGenerationDisplay(BaseModel):
|
||||
line_graph: bool
|
||||
|
||||
|
||||
class GraphType(Enum):
|
||||
class GraphType(str, Enum):
|
||||
BAR_CHART = "bar_chart"
|
||||
LINE_GRAPH = "line_graph"
|
||||
# SCATTER_PLOT = "scatter_plot"
|
||||
# PIE_CHART = "pie_chart"
|
||||
# HISTOGRAM = "histogram"
|
||||
|
||||
|
||||
class GraphingResponse(BaseModel):
|
||||
revised_query: str | None = None
|
||||
file_id: str
|
||||
graph_tye: GraphType
|
||||
|
||||
|
||||
# graph_display: GraphGenerationDisplay | None
|
||||
plot_data: dict[str, Any] | None
|
||||
graph_type: GraphType
|
||||
|
||||
|
||||
class GraphingError(BaseModel):
|
||||
|
@ -1,3 +1,4 @@
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
@ -27,6 +28,7 @@ class ToolRunner:
|
||||
if self._tool_responses is not None:
|
||||
print("prev")
|
||||
print(self._tool_responses)
|
||||
|
||||
yield from self._tool_responses
|
||||
return
|
||||
|
||||
@ -35,6 +37,11 @@ class ToolRunner:
|
||||
print(self.tool.name)
|
||||
|
||||
for tool_response in self.tool.run(llm=self._llm, **self.args):
|
||||
if isinstance(tool_response.response, bytes):
|
||||
tool_response.response = base64.b64encode(
|
||||
tool_response.response
|
||||
).decode("utf-8")
|
||||
|
||||
print("tool response")
|
||||
yield tool_response
|
||||
tool_responses.append(tool_response)
|
||||
|
@ -1 +1 @@
|
||||
{"data": [{"x": 0.0, "y": 91.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 1.0, "y": 46.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 2.0, "y": 26.41338, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 3.0, "y": 1.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 4.0, "y": 23.5, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 5.0, "y": 46.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 6.0, "y": 68.5, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 7.0, "y": 91.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}], "title": "Summary Statistics of Id", "xlabel": "Statistic", "ylabel": "Value", "xticks": [0, 1, 2, 3, 4, 5, 6, 7], "xticklabels": ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]}
|
||||
{"data": [{"x": 0.0, "y": 91.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 1.0, "y": 46.0, "width": 0.7999999999999999, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 2.0, "y": 26.41338, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 3.0, "y": 1.0, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 4.0, "y": 23.5, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 5.0, "y": 46.0, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 6.0, "y": 68.5, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 7.0, "y": 91.0, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}], "title": "Summary Statistics of Id Column", "xlabel": "Statistic", "ylabel": "Value", "xticks": [0, 1, 2, 3, 4, 5, 6, 7], "xticklabels": ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]}
|
@ -154,7 +154,6 @@ export default function AddConnector({
|
||||
initialValues={createConnectorInitialValues(connector)}
|
||||
validationSchema={createConnectorValidationSchema(connector)}
|
||||
onSubmit={async (values) => {
|
||||
console.log(" Iam submiing the connector");
|
||||
const {
|
||||
name,
|
||||
groups,
|
||||
|
@ -1143,6 +1143,7 @@ export function ChatPage({
|
||||
continue;
|
||||
}
|
||||
|
||||
console.log(packet);
|
||||
if (!initialFetchDetails) {
|
||||
if (!Object.hasOwn(packet, "user_message_id")) {
|
||||
console.error(
|
||||
|
@ -77,13 +77,6 @@ const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
INTERNET_SEARCH_TOOL_NAME,
|
||||
IMAGE_GENERATION_TOOL_NAME,
|
||||
];
|
||||
import plotDataJson from "./linechart.json";
|
||||
import barChartDataJson from "./barchart_data.json";
|
||||
import polarChartDataJson from "./polar_plot_data.json";
|
||||
import { JSONUpload } from "./JSONUpload";
|
||||
import { ImageDisplay } from "@/components/chat_display/graphs/ImageDisplay";
|
||||
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
|
||||
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
|
||||
|
||||
function FileDisplay({
|
||||
files,
|
||||
@ -164,15 +157,18 @@ function FileDisplay({
|
||||
</>
|
||||
);
|
||||
}
|
||||
export interface graph {
|
||||
file_id: string;
|
||||
line: boolean;
|
||||
|
||||
enum GraphType {
|
||||
BAR_CHART = "bar_chart",
|
||||
LINE_GRAPH = "line_graph",
|
||||
}
|
||||
|
||||
export interface GraphChunk {
|
||||
file_id: string;
|
||||
line_graph: boolean;
|
||||
plot_data: Record<string, any> | null;
|
||||
graph_type: GraphType | null;
|
||||
}
|
||||
|
||||
export const AIMessage = ({
|
||||
hasChildAI,
|
||||
hasParentAI,
|
||||
@ -208,7 +204,7 @@ export const AIMessage = ({
|
||||
shared?: boolean;
|
||||
hasChildAI?: boolean;
|
||||
hasParentAI?: boolean;
|
||||
graphs?: graph[];
|
||||
graphs?: GraphChunk[];
|
||||
isActive?: boolean;
|
||||
continueGenerating?: () => void;
|
||||
otherMessagesCanSwitchTo?: number[];
|
||||
@ -236,6 +232,7 @@ export const AIMessage = ({
|
||||
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
|
||||
setPopup?: (popupSpec: PopupSpec | null) => void;
|
||||
}) => {
|
||||
console.log(toolCall);
|
||||
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
|
||||
|
||||
const toolCallGenerating = toolCall && !toolCall.tool_result;
|
||||
@ -439,7 +436,7 @@ export const AIMessage = ({
|
||||
</div>
|
||||
)}
|
||||
{graphs.map((graph, ind) => {
|
||||
return graph.line ? (
|
||||
return graph.graph_type === GraphType.LINE_GRAPH ? (
|
||||
<ModalChartWrapper
|
||||
key={ind}
|
||||
chartType="line"
|
||||
@ -460,6 +457,17 @@ export const AIMessage = ({
|
||||
|
||||
{content || files ? (
|
||||
<>
|
||||
{toolCall?.tool_name == "create_graph" && (
|
||||
<ModalChartWrapper
|
||||
key={0}
|
||||
chartType="line"
|
||||
fileId={toolCall?.tool_result?.file_id}
|
||||
>
|
||||
<LineChartDisplay
|
||||
fileId={toolCall?.tool_result?.file_id}
|
||||
/>
|
||||
</ModalChartWrapper>
|
||||
)}
|
||||
<FileDisplay files={files || []} />
|
||||
|
||||
{typeof content === "string" ? (
|
||||
|
@ -124,7 +124,6 @@ export const CsvSection = ({
|
||||
setFadeIn(false);
|
||||
}
|
||||
}, [isLoading]);
|
||||
console.log("rerendering");
|
||||
|
||||
const downloadFile = () => {
|
||||
if (!fileId) return;
|
||||
|
@ -41,6 +41,7 @@ export function BarChartDisplay({ fileId }: { fileId: string }) {
|
||||
if (!barPlotData) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
console.log("IN THE FUNCTION");
|
||||
|
||||
// Transform data to match Recharts expected format
|
||||
const transformedData = barPlotData.data.map((point, index) => ({
|
||||
|
@ -190,6 +190,8 @@ export function LineChartDisplay({ fileId }: { fileId: string }) {
|
||||
throw new Error("Failed to fetch plot data");
|
||||
}
|
||||
const plotDataJson: PlotData = await response.json();
|
||||
console.log("plot data");
|
||||
console.log(plotDataJson);
|
||||
|
||||
const transformedData: ChartDataPoint[] = plotDataJson.data[0].x.map(
|
||||
(x, index) => ({
|
||||
@ -212,6 +214,8 @@ export function LineChartDisplay({ fileId }: { fileId: string }) {
|
||||
console.error("Error fetching plot data:", error);
|
||||
}
|
||||
};
|
||||
console.log("chartData");
|
||||
console.log(chartData);
|
||||
|
||||
return (
|
||||
<div className="w-full h-full">
|
||||
|
Loading…
x
Reference in New Issue
Block a user