l_ai_knowledge/internal/application/service/dataset.go

244 lines
6.3 KiB
Go

package service
import (
"context"
"errors"
"fmt"
"github.com/parquet-go/parquet-go"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
)
// DatasetService provides operations for working with datasets.
// It has no fields; NewDatasetService hands it out as an
// interfaces.DatasetService.
type DatasetService struct{}
// NewDatasetService builds a DatasetService and returns it behind the
// interfaces.DatasetService interface.
func NewDatasetService() interfaces.DatasetService {
	return new(DatasetService)
}
// TextInfo is one row of a text parquet file: a numeric ID plus its text.
// The queries, corpus and answers files are all decoded into this shape.
type TextInfo struct {
	// ID is the unique identifier, mapped to the "id" column.
	ID int64 `parquet:"id"`
	// Text is the text content, mapped to the "text" column.
	Text string `parquet:"text"`
}
// RelsInfo is one row of the question-passage relation parquet file,
// linking a question to a single related passage.
type RelsInfo struct {
	// QID is the question ID, mapped to the "qid" column.
	QID int64 `parquet:"qid"`
	// PID is the passage ID, mapped to the "pid" column.
	PID int64 `parquet:"pid"`
}
// QaInfo is one row of the question-answer relation parquet file,
// linking a question to its answer.
type QaInfo struct {
	// QID is the question ID, mapped to the "qid" column.
	QID int64 `parquet:"qid"`
	// AID is the answer ID, mapped to the "aid" column.
	AID int64 `parquet:"aid"`
}
// GetDatasetByID returns the QA pairs of the default sample dataset.
//
// NOTE: datasetID is currently only logged; the data always comes from the
// parquet files of the default dataset, whatever ID is requested.
//
// A missing or unreadable dataset file is reported through the returned
// error rather than by panicking.
func (d *DatasetService) GetDatasetByID(ctx context.Context, datasetID string) ([]*types.QAPair, error) {
	logger.Info(ctx, "Start getting dataset by ID")
	logger.Infof(ctx, "Getting dataset with ID: %s", datasetID)

	ds, err := loadDefaultDataset()
	if err != nil {
		return nil, fmt.Errorf("getting dataset %q: %w", datasetID, err)
	}
	ds.PrintStats(ctx)

	qaPairs := ds.Iterate()
	logger.Infof(ctx, "Retrieved %d QA pairs from dataset", len(qaPairs))
	return qaPairs, nil
}

// DefaultDataset loads and initializes the default dataset from parquet files.
//
// It panics if any of the files cannot be read; the panic value is the
// underlying error wrapped with the offending file path.
func DefaultDataset() dataset {
	ds, err := loadDefaultDataset()
	if err != nil {
		panic(err)
	}
	return ds
}

// loadDefaultDataset reads the five parquet files of the default dataset and
// indexes their rows into in-memory maps. The first file that fails to load
// aborts the whole operation and its error is returned.
func loadDefaultDataset() (dataset, error) {
	const datasetDir = "./dataset/samples"

	queries, err := loadDatasetFile[TextInfo](datasetDir, "queries")
	if err != nil {
		return dataset{}, err
	}
	corpus, err := loadDatasetFile[TextInfo](datasetDir, "corpus")
	if err != nil {
		return dataset{}, err
	}
	answers, err := loadDatasetFile[TextInfo](datasetDir, "answers")
	if err != nil {
		return dataset{}, err
	}
	qrels, err := loadDatasetFile[RelsInfo](datasetDir, "qrels")
	if err != nil {
		return dataset{}, err
	}
	qas, err := loadDatasetFile[QaInfo](datasetDir, "qas")
	if err != nil {
		return dataset{}, err
	}

	res := dataset{
		queries: make(map[int64]string, len(queries)), // qid -> question text
		corpus:  make(map[int64]string, len(corpus)),  // pid -> passage text
		answers: make(map[int64]string, len(answers)), // aid -> answer text
		qrels:   make(map[int64][]int64),              // qid -> list of pid
		qas:     make(map[int64]int64, len(qas)),      // qid -> aid
	}
	for _, qi := range queries {
		res.queries[qi.ID] = qi.Text
	}
	for _, ci := range corpus {
		res.corpus[ci.ID] = ci.Text
	}
	for _, ai := range answers {
		res.answers[ai.ID] = ai.Text
	}
	// A question may relate to several passages, so relations accumulate
	// per question instead of overwriting each other.
	for _, ri := range qrels {
		res.qrels[ri.QID] = append(res.qrels[ri.QID], ri.PID)
	}
	for _, qi := range qas {
		res.qas[qi.QID] = qi.AID
	}
	return res, nil
}

// loadDatasetFile loads "<dir>/<name>.parquet" into a slice of T and, on
// failure, wraps the error with the path that could not be loaded.
func loadDatasetFile[T any](dir, name string) ([]T, error) {
	path := fmt.Sprintf("%s/%s.parquet", dir, name)
	rows, err := loadParquet[T](path)
	if err != nil {
		return nil, fmt.Errorf("loading %s: %w", path, err)
	}
	return rows, nil
}
// dataset is the in-memory, map-indexed form of a QA dataset.
type dataset struct {
	// queries maps a question ID to its question text.
	queries map[int64]string
	// corpus maps a passage ID to its passage text.
	corpus map[int64]string
	// answers maps an answer ID to its answer text.
	answers map[int64]string
	// qrels maps a question ID to the IDs of its related passages.
	qrels map[int64][]int64
	// qas maps a question ID to the ID of its answer.
	qas map[int64]int64
}
// Iterate builds one QAPair per entry in d.queries.
//
// For each question it looks up the answer through d.qas (leaving the answer
// ID at 0 and the answer text empty when there is none) and collects the IDs
// and texts of the passages listed in d.qrels. The pairs come back in map
// iteration order, which Go does not define.
func (d *dataset) Iterate() []*types.QAPair {
	var result []*types.QAPair
	for qid, question := range d.queries {
		// Resolve the answer, if this question has one.
		var (
			answerID   int64
			answerText string
		)
		if aid, ok := d.qas[qid]; ok {
			answerID, answerText = aid, d.answers[aid]
		}

		// Collect passage IDs and texts side by side so both slices stay
		// index-aligned; they remain nil when the question has no relations.
		var (
			passageIDs   []int
			passageTexts []string
		)
		for _, pid := range d.qrels[qid] {
			passageIDs = append(passageIDs, int(pid))
			passageTexts = append(passageTexts, d.corpus[pid])
		}

		result = append(result, &types.QAPair{
			QID:      int(qid),
			Question: question,
			PIDs:     passageIDs,
			Passages: passageTexts,
			AID:      int(answerID),
			Answer:   answerText,
		})
	}
	return result
}
// GetContextForQID returns the texts of the passages related to question qid.
//
// It returns an error when qid has no entry in d.qrels. Passage IDs that are
// missing from d.corpus are skipped, so the result can be shorter than the
// relation list (and is nil when nothing matches).
func (d *dataset) GetContextForQID(qid int64) ([]string, error) {
	related, found := d.qrels[qid]
	if !found {
		return nil, errors.New("question ID not found")
	}

	var passages []string
	for i := range related {
		text, ok := d.corpus[related[i]]
		if !ok {
			continue
		}
		passages = append(passages, text)
	}
	return passages, nil
}
// PrintStats logs summary statistics about the dataset: the number of
// queries, passages and answers, the average number of related passages, and
// the answer coverage.
//
// Both ratios are reported as 0 when their denominator is zero, so an empty
// dataset logs 0.00 instead of NaN or +Inf.
func (d *dataset) PrintStats(ctx context.Context) {
	logger.Infof(ctx, "QA System Statistics:")
	logger.Infof(ctx, "- Total queries: %d", len(d.queries))
	logger.Infof(ctx, "- Total corpus passages: %d", len(d.corpus))
	logger.Infof(ctx, "- Total answers: %d", len(d.answers))

	// Average passages per query. The denominator is the number of questions
	// that appear in qrels, not the total number of queries.
	totalRelations := 0
	for _, pids := range d.qrels {
		totalRelations += len(pids)
	}
	var avgPassages float64
	if n := len(d.qrels); n > 0 {
		avgPassages = float64(totalRelations) / float64(n)
	}
	logger.Infof(ctx, "- Average passages per query: %.2f", avgPassages)

	// Coverage: number of question-answer links relative to the query count.
	coveredQueries := len(d.qas)
	var coverage float64
	if n := len(d.queries); n > 0 {
		coverage = float64(coveredQueries) / float64(n) * 100
	}
	logger.Infof(ctx, "- Answer coverage: %.2f%% (%d/%d)", coverage, coveredQueries, len(d.queries))
}
// PrintRandomQA prints one question to standard output together with its
// related passages and its answer.
//
// The question is whichever entry of d.qas a map range visits first, so the
// pick is arbitrary rather than uniformly random. An error is returned when
// d.qas is empty, or when the chosen question or its answer has no text in
// d.queries / d.answers.
func (d *dataset) PrintRandomQA() error {
	// Track "found" explicitly: 0 is a legitimate question ID and must not be
	// mistaken for "no entry was visited".
	var (
		qid, aid int64
		found    bool
	)
	for k, v := range d.qas {
		qid, aid, found = k, v, true
		break
	}
	if !found {
		return errors.New("no questions available")
	}

	question, ok := d.queries[qid]
	if !ok {
		return fmt.Errorf("question %d not found", qid)
	}
	answer, ok := d.answers[aid]
	if !ok {
		return fmt.Errorf("answer %d not found", aid)
	}

	fmt.Println("===== Random QA =====")
	fmt.Printf("QID: %d\n", qid)
	fmt.Printf("Question: %s\n", question)

	// A missing qrels entry yields a nil slice, which is handled the same way
	// as an empty relation list.
	pids := d.qrels[qid]
	if len(pids) == 0 {
		fmt.Println("\nNo related passages found")
	} else {
		fmt.Println("\nRelated passages:")
		for i, pid := range pids {
			// Passages without text in the corpus are skipped; numbering
			// still follows the position in the relation list.
			if text, exists := d.corpus[pid]; exists {
				fmt.Printf("\nPassage %d (PID: %d):\n%s\n", i+1, pid, text)
			}
		}
	}

	fmt.Printf("\nAnswer (AID: %d):\n%s\n", aid, answer)
	return nil
}
// loadParquet reads the parquet file at filePath and decodes its rows into a
// slice of T. On failure it returns a nil slice together with the error from
// parquet.ReadFile, never rows alongside an error.
func loadParquet[T any](filePath string) (records []T, err error) {
	if records, err = parquet.ReadFile[T](filePath); err != nil {
		return nil, err
	}
	return records, nil
}