l_ai_knowledge/internal/application/service/metric/recall_test.go

71 lines
1.6 KiB
Go

package metric
import (
"testing"
"knowlege-lsxd/internal/types"
)
// TestRecallMetric_Compute runs the metric returned by NewRecallMetric over a
// table of ground-truth / retrieval combinations and checks the score.
//
// The expected values follow the formula recorded with the partial-recall
// case: for every ground-truth set, take the fraction of its IDs that appear
// in RetrievalIDs, then average those fractions over all sets.
func TestRecallMetric_Compute(t *testing.T) {
	// tolerance absorbs floating-point rounding differences so the test does
	// not depend on the exact order in which Compute sums its terms.
	const tolerance = 1e-9

	tests := []struct {
		name     string
		input    *types.MetricInput
		expected float64
	}{
		{
			name: "perfect recall - all ground truth retrieved",
			input: &types.MetricInput{
				RetrievalGT:  [][]int{{1, 2, 3}},
				RetrievalIDs: []int{1, 2, 3, 4},
			},
			expected: 1.0,
		},
		{
			name: "partial recall - some ground truth retrieved",
			input: &types.MetricInput{
				RetrievalGT:  [][]int{{1, 2, 3}, {4, 5}},
				RetrievalIDs: []int{1, 4, 6},
			},
			// One ID from each set is retrieved (1 and 4):
			// (1/3 + 1/2) / 2 ≈ 0.41667.
			expected: 0.41666666666666663,
		},
		{
			name: "no recall - no ground truth retrieved",
			input: &types.MetricInput{
				RetrievalGT:  [][]int{{1, 2, 3}},
				RetrievalIDs: []int{4, 5, 6},
			},
			expected: 0.0,
		},
		{
			name: "empty retrieval list",
			input: &types.MetricInput{
				RetrievalGT:  [][]int{{1, 2, 3}},
				RetrievalIDs: []int{},
			},
			expected: 0.0,
		},
		{
			name: "multiple ground truth sets",
			input: &types.MetricInput{
				RetrievalGT:  [][]int{{1, 2}, {3, 4}, {5, 6}},
				RetrievalIDs: []int{1, 3, 7},
			},
			// The first two sets are each half retrieved (1 and 3), the
			// third not at all: (1/2 + 1/2 + 0) / 3 = 1/3.
			expected: 0.3333333333333333,
		},
	}

	rm := NewRecallMetric()
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := rm.Compute(tt.input)
			// Negated range check: every comparison with NaN is false, so a
			// NaN result fails the test instead of slipping through.
			if diff := got - tt.expected; !(diff >= -tolerance && diff <= tolerance) {
				t.Errorf("Compute() = %v, want %v (tolerance %v)", got, tt.expected, tolerance)
			}
		})
	}
}