ai_scheduler/internal/tools/calculator.go

122 lines
3.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tools
import (
"ai_scheduler/internal/entitys"
"context"
"encoding/json"
"fmt"
"math"
)
// CalculatorTool 计算器工具
type CalculatorTool struct{}
// NewCalculatorTool 创建计算器工具
func NewCalculatorTool() *CalculatorTool {
return &CalculatorTool{}
}
// Name 返回工具名称
func (c *CalculatorTool) Name() string {
return "calculate"
}
// Description 返回工具描述
func (c *CalculatorTool) Description() string {
return "执行基本的数学运算,支持加减乘除和幂运算"
}
// Definition 返回工具定义
func (c *CalculatorTool) Definition() entitys.ToolDefinition {
return entitys.ToolDefinition{
Type: "function",
Function: entitys.FunctionDef{
Name: c.Name(),
Description: c.Description(),
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"operation": map[string]interface{}{
"type": "string",
"description": "运算类型",
"enum": []string{"add", "subtract", "multiply", "divide", "power"},
},
"a": map[string]interface{}{
"type": "number",
"description": "第一个数字",
},
"b": map[string]interface{}{
"type": "number",
"description": "第二个数字",
},
},
"required": []string{"operation", "a", "b"},
},
},
}
}
// CalculateRequest 计算请求参数
type CalculateRequest struct {
Operation string `json:"operation"`
A float64 `json:"a"`
B float64 `json:"b"`
}
// CalculateResponse 计算响应
type CalculateResponse struct {
Operation string `json:"operation"`
A float64 `json:"a"`
B float64 `json:"b"`
Result float64 `json:"result"`
Expression string `json:"expression"`
}
// Execute 执行计算
func (c *CalculatorTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) {
var req CalculateRequest
if err := json.Unmarshal(args, &req); err != nil {
return nil, fmt.Errorf("invalid calculate request: %w", err)
}
var result float64
var expression string
switch req.Operation {
case "add":
result = req.A + req.B
expression = fmt.Sprintf("%.2f + %.2f = %.2f", req.A, req.B, result)
case "subtract":
result = req.A - req.B
expression = fmt.Sprintf("%.2f - %.2f = %.2f", req.A, req.B, result)
case "multiply":
result = req.A * req.B
expression = fmt.Sprintf("%.2f × %.2f = %.2f", req.A, req.B, result)
case "divide":
if req.B == 0 {
return nil, fmt.Errorf("division by zero is not allowed")
}
result = req.A / req.B
expression = fmt.Sprintf("%.2f ÷ %.2f = %.2f", req.A, req.B, result)
case "power":
result = math.Pow(req.A, req.B)
expression = fmt.Sprintf("%.2f ^ %.2f = %.2f", req.A, req.B, result)
default:
return nil, fmt.Errorf("unsupported operation: %s", req.Operation)
}
// 检查结果是否有效
if math.IsNaN(result) || math.IsInf(result, 0) {
return nil, fmt.Errorf("calculation resulted in invalid number")
}
return &CalculateResponse{
Operation: req.Operation,
A: req.A,
B: req.B,
Result: result,
Expression: expression,
}, nil
}