package metric import ( "testing" "knowlege-lsxd/internal/types" ) func TestMAPMetric_Compute(t *testing.T) { tests := []struct { name string input *types.MetricInput expected float64 }{ { name: "total match", input: &types.MetricInput{ RetrievalGT: [][]int{{2, 4, 6}}, RetrievalIDs: []int{2, 4, 6}, }, expected: 1.0, }, { name: "no match", input: &types.MetricInput{ RetrievalGT: [][]int{{1, 2}}, RetrievalIDs: []int{3, 4}, }, expected: 0.0, }, { name: "partial match", input: &types.MetricInput{ RetrievalGT: [][]int{{1, 2, 3}}, RetrievalIDs: []int{2, 5, 1, 3}, }, // AP = (1/1 + 2/3 + 3/4)/3 ≈ 0.80555555 expected: 0.8055555555555555, }, { name: "empty ground truth", input: &types.MetricInput{ RetrievalGT: [][]int{}, RetrievalIDs: []int{1, 2}, }, expected: 0.0, }, { name: "multiple queries", input: &types.MetricInput{ RetrievalGT: [][]int{ {1, 2}, {3, 4}, }, RetrievalIDs: []int{1, 3, 2, 4}, }, // Query1 AP: (1/1 + 2/3)/2 ≈ 0.8333 // Query2 AP: (1/2 + 2/4)/2 = 0.5 // MAP: (0.8333 + 0.5)/2 ≈ 0.6667 expected: 0.6666666666666666, }, } metric := NewMAPMetric() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := metric.Compute(tt.input) if !almostEqual(got, tt.expected, 1e-6) { t.Errorf("Compute() = %v, want %v", got, tt.expected) } }) } } // Helper function to compare floating point numbers with tolerance func almostEqual(a, b, tolerance float64) bool { if a == b { return true } diff := a - b if diff < 0 { diff = -diff } return diff < tolerance }