71 lines
1.6 KiB
Go
71 lines
1.6 KiB
Go
package metric
|
|
|
|
import (
	"math"
	"testing"

	"knowlege-lsxd/internal/types"
)
|
|
|
|
func TestRecallMetric_Compute(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input *types.MetricInput
|
|
expected float64
|
|
}{
|
|
{
|
|
name: "perfect recall - all ground truth retrieved",
|
|
input: &types.MetricInput{
|
|
RetrievalGT: [][]int{{1, 2, 3}},
|
|
RetrievalIDs: []int{1, 2, 3, 4},
|
|
},
|
|
expected: 1.0,
|
|
},
|
|
{
|
|
name: "partial recall - some ground truth retrieved",
|
|
input: &types.MetricInput{
|
|
RetrievalGT: [][]int{{1, 2, 3}, {4, 5}},
|
|
RetrievalIDs: []int{1, 4, 6},
|
|
},
|
|
// 命中2个ground truth集合中的元素(a和d)
|
|
expected: 0.41666666666666663, // (1/3 + 1/2) / 2 = 0.41666 (每个ground truth集合只要命中一个就算召回)
|
|
|
|
},
|
|
{
|
|
name: "no recall - no ground truth retrieved",
|
|
input: &types.MetricInput{
|
|
RetrievalGT: [][]int{{1, 2, 3}},
|
|
RetrievalIDs: []int{4, 5, 6},
|
|
},
|
|
expected: 0.0,
|
|
},
|
|
{
|
|
name: "empty retrieval list",
|
|
input: &types.MetricInput{
|
|
RetrievalGT: [][]int{{1, 2, 3}},
|
|
RetrievalIDs: []int{},
|
|
},
|
|
expected: 0.0,
|
|
},
|
|
{
|
|
name: "multiple ground truth sets",
|
|
input: &types.MetricInput{
|
|
RetrievalGT: [][]int{{1, 2}, {3, 4}, {5, 6}},
|
|
RetrievalIDs: []int{1, 3, 7},
|
|
},
|
|
// 命中了前两个ground truth集合(a和c)
|
|
expected: 0.3333333333333333, // 1/3≈0.333...
|
|
},
|
|
}
|
|
|
|
rm := NewRecallMetric()
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := rm.Compute(tt.input)
|
|
if got != tt.expected {
|
|
t.Errorf("Compute() = %v, want %v", got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|