382 lines
12 KiB
Python
382 lines
12 KiB
Python
"""
|
|
QA Dataset Sampling Tool
|
|
|
|
```
|
|
pip install pandas pyarrow
|
|
pip install openai
|
|
```
|
|
|
|
# 采样数据
|
|
python dataset/qa_dataset.py sample \
|
|
--queries ~/dataset/mmarco-queries.parquet \
|
|
--corpus ~/dataset/mmarco-corpus.parquet \
|
|
--qrels ~/dataset/mmarco-qrels.parquet \
|
|
--nq 100 \
|
|
--output_dir ./dataset/samples
|
|
|
|
# 生成答案(基于采样结果)
|
|
python dataset/qa_dataset.py generate \
|
|
--input_dir ./dataset/samples \
|
|
--output_dir ./dataset/samples
|
|
|
|
# 展示结果
|
|
python dataset/qa_dataset.py show \
|
|
--input_dir ./dataset/samples \
|
|
-n 1
|
|
"""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
import argparse
|
|
|
|
import pandas as pd
|
|
import openai
|
|
|
|
|
|
def read_parquet(path):
    """Load a parquet file into a DataFrame.

    Args:
        path: Path to the parquet file.

    Returns:
        The loaded pandas DataFrame.
    """
    return pd.read_parquet(path)
|
|
|
|
|
|
def save_to_parquet(df: pd.DataFrame, path: str):
    """Write *df* to a parquet file, creating parent directories as needed.

    Args:
        df: DataFrame to persist.
        path: Destination parquet file path.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    df.to_parquet(path)
    print(f"Saved to {path}")
|
|
|
|
|
|
def print_stats(df: pd.DataFrame, name: str):
    """Print record count and per-key uniqueness statistics for *df*.

    Args:
        df: DataFrame to summarize.
        name: Label shown in the header line.
    """
    print(f"\n{name} Statistics:")
    print(f"- Total records: {len(df)}")
    # Only report uniqueness for the key columns that actually exist.
    for column, label in (("id", "ids"), ("qid", "qids"), ("pid", "pids")):
        if column in df.columns:
            print(f"- Unique {label}: {df[column].nunique()}")
|
|
|
|
|
|
def sample_data(
    queries: pd.DataFrame, corpus: pd.DataFrame, qrels: pd.DataFrame, nq=1000
):
    """
    Sample a subset of the dataset with validation checks.

    Args:
        queries: DataFrame with "id" and "text" columns (one row per query)
        corpus: DataFrame with "id" and "text" columns (one row per passage)
        qrels: DataFrame with "qid" and "pid" columns (many-to-many)
        nq: Number of queries to sample (default: 1000)

    Returns:
        Tuple of (sampled_queries, sampled_corpus, sampled_qrels)
    """
    # 1. Filter qrels to only include qids that exist in queries
    valid_qids = set(queries["id"])
    qrels = qrels[qrels["qid"].isin(valid_qids)]

    # 2. Filter qrels to only include pids that exist in corpus
    valid_pids = set(corpus["id"])
    qrels = qrels[qrels["pid"].isin(valid_pids)]

    # 3. Sample queries: take the qids with the most associated pids so
    #    each sampled query has rich relevance data.
    qid_counts = qrels["qid"].value_counts()
    sampled_qids = qid_counts.nlargest(min(nq, len(qid_counts))).index

    # 4. Get all pids associated with sampled qids
    sampled_qrels = qrels[qrels["qid"].isin(sampled_qids)]
    sampled_pids = set(sampled_qrels["pid"])

    # 5. Add extra distractor pids for redundancy (20% of sampled pids).
    #    Cap at the corpus size: DataFrame.sample raises ValueError when
    #    asked for more rows than exist (sampling is without replacement).
    n_extra = min(int(0.2 * len(sampled_pids)), len(corpus))
    extra_pids = set(corpus["id"].sample(n_extra))
    all_pids = sampled_pids.union(extra_pids)

    # 6. Create final sampled datasets
    sampled_queries = queries[queries["id"].isin(sampled_qids)]
    sampled_corpus = corpus[corpus["id"].isin(all_pids)]

    return sampled_queries, sampled_corpus, sampled_qrels
|
|
|
|
|
|
class QAAnsweringSystem:
    """Answers queries with an LLM, using qrels-linked passages as context."""

    def __init__(
        self, queries: pd.DataFrame, corpus: pd.DataFrame, qrels: pd.DataFrame
    ):
        """
        Initialize QA system with data.

        Args:
            queries: DataFrame with "id" and "text" columns
            corpus: DataFrame with "id" and "text" columns
            qrels: DataFrame with "qid" and "pid" mapping
        """
        self.queries = queries
        self.corpus = corpus
        self.qrels = qrels
        # Credentials/endpoint come from the environment.
        self.client = openai.Client(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_BASE_URL"),
        )

        # Lookup tables: id -> text, and query id -> list of passage ids.
        self.qid_to_text = {q: t for q, t in zip(queries["id"], queries["text"])}
        self.pid_to_text = {p: t for p, t in zip(corpus["id"], corpus["text"])}
        self.qid_to_pids = qrels.groupby("qid")["pid"].apply(list).to_dict()

    def get_context_for_qid(self, qid: str) -> str:
        """
        Get all relevant passage text for a query ID.

        Args:
            qid: Query ID to search for

        Returns:
            Combined context text from all related passages

        Raises:
            ValueError: If the qid has no qrels entry.
        """
        if qid not in self.qid_to_pids:
            raise ValueError("Question ID not found")

        pids = self.qid_to_pids[qid]
        print(f"Context for Question ID {qid}: {pids}")
        # Silently skip pids whose passage text is missing from the corpus.
        parts = [self.pid_to_text[pid] for pid in pids if pid in self.pid_to_text]
        return "\n\n".join(parts)

    def answer_question(self, qid: str, model: str = "gpt-4o-2024-05-13") -> str:
        """
        Use the OpenAI API to answer the question identified by *qid*.

        Args:
            qid: Query ID to answer
            model: OpenAI model to use

        Returns:
            Generated answer from the LLM

        Raises:
            ValueError: If the qid is unknown or has no context.
        """
        if qid not in self.qid_to_text:
            raise ValueError("Question ID not found")

        question = self.qid_to_text[qid]
        context = self.get_context_for_qid(qid)
        if not context:
            raise ValueError("No context found for this question")

        prompt = f"""Answer the question based on the context below. Keep the answer concise.

Question: {question}

Context: {context}

Answer:"""
        response = self.client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
        )
        return response.choices[0].message.content
|
|
|
|
|
|
def sample_command(args):
|
|
"""Handle sample command"""
|
|
# Load data
|
|
print("Loading data...")
|
|
queries = read_parquet(args.queries)
|
|
corpus = read_parquet(args.corpus)
|
|
qrels = read_parquet(args.qrels)
|
|
|
|
# Print original stats
|
|
print("\nOriginal Dataset Statistics:")
|
|
print_stats(queries, "Queries")
|
|
print_stats(corpus, "Corpus")
|
|
print_stats(qrels, "Qrels")
|
|
|
|
# Sample data
|
|
print(f"\nSampling {args.nq} queries...")
|
|
sampled_queries, sampled_corpus, sampled_qrels = sample_data(
|
|
queries, corpus, qrels, args.nq
|
|
)
|
|
|
|
# Print sampled stats
|
|
print("\nSampled Dataset Statistics:")
|
|
print_stats(sampled_queries, "Sampled Queries")
|
|
print_stats(sampled_corpus, "Sampled Corpus")
|
|
print_stats(sampled_qrels, "Sampled Qrels")
|
|
|
|
# Save sampled data
|
|
print("\nSaving sampled data...")
|
|
save_to_parquet(sampled_queries, f"{args.output_dir}/queries.parquet")
|
|
save_to_parquet(sampled_corpus, f"{args.output_dir}/corpus.parquet")
|
|
save_to_parquet(sampled_qrels, f"{args.output_dir}/qrels.parquet")
|
|
print("\nSampling completed successfully!")
|
|
|
|
|
|
def generate_answers(input_dir: str, output_dir: str, max_retries: int = 3):
    """
    Generate answers for sampled queries with resume support.

    Progress is persisted after every successful answer so an interrupted
    run can resume without re-querying the LLM for already-finished qids.

    Args:
        input_dir: Directory containing sampled queries/corpus/qrels
        output_dir: Directory to save answer files
        max_retries: Maximum retry attempts for failed queries
    """
    print("\nLoading sampled data...")
    queries = read_parquet(f"{input_dir}/queries.parquet")
    corpus = read_parquet(f"{input_dir}/corpus.parquet")
    qrels = read_parquet(f"{input_dir}/qrels.parquet")

    answers_path = f"{output_dir}/answers.parquet"
    qa_pairs_path = f"{output_dir}/qas.parquet"

    # Try to load existing answers, if any, to resume a previous run.
    try:
        existing_answers = read_parquet(answers_path)
        existing_qas = read_parquet(qa_pairs_path)
        processed_qids = set(existing_qas["qid"])
        print(f"\nFound {len(processed_qids)} previously processed queries")
    except (FileNotFoundError, KeyError):
        print("No existing answers found, use empty state")
        existing_answers = pd.DataFrame(columns=["id", "text"])
        existing_qas = pd.DataFrame(columns=["qid", "aid"])
        processed_qids = set()

    qa_system = QAAnsweringSystem(queries, corpus, qrels)

    answers = existing_answers.to_dict("records")
    qa_pairs = existing_qas.to_dict("records")
    # Continue after the highest existing id. `len(answers) + 1` would
    # collide with stored aids whenever the ids are not contiguous from 1
    # (e.g. a resumed run after partial failures or a hand-edited file).
    answer_id_counter = max((rec["id"] for rec in answers), default=0) + 1

    for qid in queries["id"]:
        if qid in processed_qids:
            continue

        retry_count = 0
        while retry_count <= max_retries:
            try:
                answer_text = qa_system.answer_question(qid)
                aid = answer_id_counter
                answers.append({"id": aid, "text": answer_text})
                qa_pairs.append({"qid": qid, "aid": aid})
                answer_id_counter += 1

                # Save progress after each successful answer
                save_to_parquet(pd.DataFrame(answers), answers_path)
                save_to_parquet(pd.DataFrame(qa_pairs), qa_pairs_path)
                print(f"Processed qid: {qid}")
                break
            except (openai.APIError, openai.APIConnectionError) as e:
                retry_count += 1
                if retry_count > max_retries:
                    print(
                        f"\nFailed to process qid {qid} after {max_retries} attempts: {str(e)}"
                    )
                    # Persist whatever succeeded before giving up on this qid.
                    save_to_parquet(pd.DataFrame(answers), answers_path)
                    save_to_parquet(pd.DataFrame(qa_pairs), qa_pairs_path)
                else:
                    print(f"\nRetry {retry_count} for qid {qid}...")

    print("\nAnswer generation completed!")
    print(f"Total queries: {len(queries)}")
    print(f"Successfully processed: {len(qa_pairs)}")
    print(f"Failed queries: {len(queries) - len(qa_pairs)}")
|
|
|
|
|
|
def show_results(input_dir: str, n: int = 5):
    """
    Show n random results with question, context and answer.

    Args:
        input_dir: Directory containing the QA data
        n: Number of results to show (default: 5)
    """
    print(f"\nShowing {n} random results:")

    # Load all stored artifacts.
    queries = read_parquet(f"{input_dir}/queries.parquet")
    corpus = read_parquet(f"{input_dir}/corpus.parquet")
    qrels = read_parquet(f"{input_dir}/qrels.parquet")
    qa_pairs = read_parquet(f"{input_dir}/qas.parquet")
    answers = read_parquet(f"{input_dir}/answers.parquet")

    # QA system provides the qid -> text and qid -> context lookups.
    qa_system = QAAnsweringSystem(queries, corpus, qrels)

    for _, pair in qa_pairs.sample(n).iterrows():
        qid, aid = pair["qid"], pair["aid"]

        question = qa_system.qid_to_text[qid]
        context = qa_system.get_context_for_qid(qid)
        answer = answers.loc[answers["id"] == aid, "text"].values[0]

        print("\n" + "=" * 50)
        print(f"Question (qid={qid}):\n{question}")
        print("\nContext:")
        print(context)
        print(f"\nAnswer (aid={aid}):\n{answer}")
        print("=" * 50 + "\n")
|
|
|
|
|
|
def main():
    """CLI entry point: build the parser and dispatch to the subcommand."""
    parser = argparse.ArgumentParser(description="QA Dataset Tool")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # sample: draw a subset of the full dataset
    sample_parser = subparsers.add_parser("sample", help="Sample dataset")
    sample_arguments = (
        ("--queries", dict(type=str, required=True, help="Path to queries parquet file")),
        ("--corpus", dict(type=str, required=True, help="Path to corpus parquet file")),
        ("--qrels", dict(type=str, required=True, help="Path to qrels parquet file")),
        ("--nq", dict(type=int, default=1000, help="Number of queries to sample")),
        ("--output_dir", dict(type=str, default="./save", help="Output directory")),
    )
    for flag, options in sample_arguments:
        sample_parser.add_argument(flag, **options)
    sample_parser.set_defaults(func=sample_command)

    # generate: answer the sampled queries with the LLM
    generate_parser = subparsers.add_parser("generate", help="Generate answers")
    generate_parser.add_argument(
        "--input_dir", type=str, required=True, help="Directory with sampled data"
    )
    generate_parser.add_argument(
        "--output_dir", type=str, default="./save", help="Output directory"
    )
    generate_parser.set_defaults(
        func=lambda args: generate_answers(args.input_dir, args.output_dir)
    )

    # show: pretty-print stored QA pairs
    show_parser = subparsers.add_parser("show", help="Show QA results")
    show_parser.add_argument(
        "--input_dir", type=str, required=True, help="Directory with QA data"
    )
    show_parser.add_argument(
        "-n", type=int, default=5, help="Number of results to show (default: 5)"
    )
    show_parser.set_defaults(func=lambda args: show_results(args.input_dir, args.n))

    args = parser.parse_args()
    args.func(args)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|