l_ai_knowledge/internal/application/service/metric/bleu.go

166 lines
3.8 KiB
Go

package metric
// Adapted from https://github.com/waygo/bleu.
// This file implements the BLEU (bilingual evaluation understudy) method,
// which is used to evaluate the quality of machine translation. [1]
//
// The code in this file was largely ported from the corresponding module
// in Python NLTK. [2]
//
// [1] Papineni, Kishore, et al. "BLEU: a method for automatic evaluation of
// machine translation." Proceedings of the 40th annual meeting on
// association for computational linguistics. Association for Computational
// Linguistics, 2002.
//
// [2] http://www.nltk.org/_modules/nltk/align/bleu.html
import (
"encoding/json"
"log"
"math"
"strings"
"knowlege-lsxd/internal/types"
)
// BLEUMetric scores a generated text against a ground-truth text using BLEU.
// Construct it with NewBLEUMetric.
type BLEUMetric struct {
	// smoothing, when true, adds 1 to both the numerator and denominator of
	// every modified n-gram precision (add-one smoothing).
	smoothing bool
	// weights[i] is the weight applied to the log of the (i+1)-gram precision;
	// len(weights) is the highest n-gram order that is scored.
	weights BLEUWeight
}
// NewBLEUMetric returns a BLEU scorer that optionally applies add-one
// smoothing to each n-gram precision and combines the per-order precisions
// with the given weights (see BLEU1Gram..BLEU4Gram for presets).
// The weights slice is stored as-is, not copied.
func NewBLEUMetric(smoothing bool, weights BLEUWeight) *BLEUMetric {
	m := new(BLEUMetric)
	m.smoothing = smoothing
	m.weights = weights
	return m
}
// Sentence is a tokenized text with one element per word, as produced by
// splitIntoWords in Compute.
type Sentence []string

// BLEUWeight holds per-order weights: element i weights the (i+1)-gram
// precision, and its length is the highest n-gram order that is scored.
type BLEUWeight []float64

// Preset weightings for BLEU-1 through BLEU-4. BLEU-N spreads a total weight
// of 1 uniformly over n-gram orders 1..N; higher orders get weight 0.
var (
	BLEU1Gram BLEUWeight = []float64{1.0, 0.0, 0.0, 0.0}
	BLEU2Gram BLEUWeight = []float64{0.5, 0.5, 0.0, 0.0}
	// Exact thirds: the previous 0.33 entries summed to 0.99, which scaled the
	// whole log-precision sum by 0.99 and biased BLEU-3 scores upward.
	BLEU3Gram BLEUWeight = []float64{1.0 / 3, 1.0 / 3, 1.0 / 3, 0.0}
	BLEU4Gram BLEUWeight = []float64{0.25, 0.25, 0.25, 0.25}
)
// Compute returns the BLEU score of metricInput.GeneratedTexts (the candidate)
// measured against metricInput.GeneratedGT (the single reference).
//
// Both texts are tokenized with splitSentences/splitIntoWords and lowercased
// in place, so matching is case-insensitive. For every order n = i+1 whose
// weight b.weights[i] is non-zero, the modified n-gram precision p_n is
// computed, and the score is
//
//	BP * exp(sum_n w_n * log(p_n))
//
// where BP is the brevity penalty. The result is 0 when no weighted order has
// a positive precision, and it lies in [0, 1] for non-negative weights.
func (b *BLEUMetric) Compute(metricInput *types.MetricInput) float64 {
	candidate := splitIntoWords(splitSentences(metricInput.GeneratedTexts))
	references := []Sentence{splitIntoWords(splitSentences(metricInput.GeneratedGT))}
	for i := range candidate {
		candidate[i] = strings.ToLower(candidate[i])
	}
	for i := range references {
		for u := range references[i] {
			references[i][u] = strings.ToLower(references[i][u])
		}
	}

	// s accumulates the weighted log precisions; overlap counts the weighted
	// orders that actually had a positive precision.
	s := 0.0
	overlap := 0
	for i, w := range b.weights {
		// A zero-weight order cannot contribute to the score. Skip its
		// (comparatively expensive) n-gram extraction, and do not let it count
		// as overlap, which would otherwise yield bp*exp(0) with no evidence.
		if w == 0 {
			continue
		}
		pn := b.modifiedPrecision(candidate, references, i+1)
		// NOTE(review): in the original BLEU definition any zero precision
		// drives the score to 0 (log 0 = -Inf). Here such orders are simply
		// left out and the remaining weights are not renormalised, which
		// favours short candidates. Kept for compatibility; confirm intended.
		if pn > 0.0 {
			overlap++
			s += w * math.Log(pn)
		}
	}
	if overlap == 0 {
		return 0
	}
	return b.brevityPenalty(candidate, references) * math.Exp(s)
}
// phrase is an n-gram: a contiguous run of words taken from a Sentence.
type phrase []string

// String returns the JSON array encoding of the n-gram, e.g. ["a","b"].
// countNphrase uses it as a map key; the JSON form keeps word boundaries
// unambiguous (words containing spaces or quotes cannot run together).
func (p phrase) String() string {
	encoded, err := json.Marshal(p)
	if err != nil {
		// Marshalling a slice of strings is not expected to fail, so this is a
		// broken invariant. log.Panic, unlike log.Fatal, does not os.Exit the
		// whole service; callers or HTTP middleware can still recover.
		log.Panic("encode error:", err)
	}
	return string(encoded)
}
// getNphrase returns every contiguous n-gram of s, ordered by start position.
// It returns an empty (non-nil) slice when s has fewer than n words.
// The returned phrases are sub-slices of s, so they share its backing array.
func (b *BLEUMetric) getNphrase(s Sentence, n int) []phrase {
	count := len(s) - n + 1
	if count < 0 {
		count = 0
	}
	// The number of n-grams is known up front, so allocate once.
	nphrase := make([]phrase, 0, count)
	for i := 0; i < count; i++ {
		nphrase = append(nphrase, phrase(s[i:i+n]))
	}
	return nphrase
}
// countNphrase tallies how often each n-gram occurs in nphrase, keyed by the
// n-gram's String() encoding.
func (b *BLEUMetric) countNphrase(nphrase []phrase) map[string]int {
	tally := make(map[string]int, len(nphrase))
	for i := range nphrase {
		key := nphrase[i].String()
		tally[key]++
	}
	return tally
}
// modifiedPrecision returns BLEU's clipped n-gram precision p_n of candidate.
// Each candidate n-gram count is clipped to the largest number of times that
// n-gram occurs in any single reference. The total clipped count is then
// divided by the total candidate n-gram count. With smoothing enabled, 1 is
// added to both numerator and denominator. It returns 0 when the candidate has
// fewer than n words.
func (b *BLEUMetric) modifiedPrecision(candidate Sentence, references []Sentence, n int) float64 {
	candidateCounts := b.countNphrase(b.getNphrase(candidate, n))
	if len(candidateCounts) == 0 {
		return 0.0
	}

	// Highest count of each candidate n-gram seen in any one reference;
	// n-grams absent from every reference read back as 0.
	maxRefCounts := make(map[string]int, len(candidateCounts))
	for _, ref := range references {
		refCounts := b.countNphrase(b.getNphrase(ref, n))
		for gram := range candidateCounts {
			if refCount := refCounts[gram]; refCount > maxRefCounts[gram] {
				maxRefCounts[gram] = refCount
			}
		}
	}

	clipped := make(map[string]int, len(candidateCounts))
	for gram, count := range candidateCounts {
		clipped[gram] = min(count, maxRefCounts[gram])
	}

	var smoothingFactor float64
	if b.smoothing {
		smoothingFactor = 1.0
	}
	numerator := float64(sum(clipped)) + smoothingFactor
	denominator := float64(sum(candidateCounts)) + smoothingFactor
	return numerator / denominator
}
// brevityPenalty returns BLEU's brevity penalty for candidate.
//
// It compares the candidate length c with r, the length of the reference
// closest in length to it (ties go to the earliest reference). It returns 1
// when c > r and exp(1 - r/c) otherwise.
//
// Edge cases:
//   - An empty candidate yields 0, as in NLTK. Previously this gave NaN when
//     the closest reference was also empty.
//   - An empty reference list is treated as a single empty reference instead
//     of panicking on an out-of-range index.
func (b *BLEUMetric) brevityPenalty(candidate Sentence, references []Sentence) float64 {
	c := len(candidate)
	r := 0 // closest reference length; stays 0 when there are no references
	bestDiff := -1
	for _, ref := range references {
		if d := abs(len(ref) - c); bestDiff == -1 || d < bestDiff {
			r, bestDiff = len(ref), d
		}
	}
	switch {
	case c > r:
		return 1
	case c == 0:
		// Avoids r/0 (±Inf or NaN); an empty candidate should score 0.
		return 0
	}
	return math.Exp(1 - float64(r)/float64(c))
}