l_ai_knowledge/internal/application/repository/message.go

154 lines
4.4 KiB
Go

package repository
import (
"context"
"slices"
"time"
"gorm.io/gorm"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
)
// messageRepository is the GORM-backed implementation of
// interfaces.MessageRepository; every query it issues goes through db.
type messageRepository struct {
	db *gorm.DB
}

// NewMessageRepository returns a MessageRepository that reads and writes
// messages through the given GORM handle.
func NewMessageRepository(db *gorm.DB) interfaces.MessageRepository {
	repo := &messageRepository{db: db}
	return repo
}
// CreateMessage inserts message into the database.
//
// On success it returns the same pointer it was given; on failure it returns
// a nil message together with the database error.
func (r *messageRepository) CreateMessage(
	ctx context.Context, message *types.Message,
) (*types.Message, error) {
	err := r.db.WithContext(ctx).Create(message).Error
	if err != nil {
		return nil, err
	}
	return message, nil
}
// GetMessage loads the message with the given messageID that belongs to
// sessionID.
//
// Any error from the lookup (including GORM's not-found error from First) is
// returned unchanged with a nil message.
func (r *messageRepository) GetMessage(
	ctx context.Context, sessionID string, messageID string,
) (*types.Message, error) {
	found := new(types.Message)
	query := r.db.WithContext(ctx).Where("id = ? AND session_id = ?", messageID, sessionID)
	if err := query.First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetMessagesBySession retrieves one page of a session's messages, oldest
// first.
//
// page is 1-based; values below 1 are treated as the first page. pageSize is
// passed straight through to LIMIT.
func (r *messageRepository) GetMessagesBySession(
	ctx context.Context, sessionID string, page int, pageSize int,
) ([]*types.Message, error) {
	// Clamp explicitly rather than relying on the ORM silently dropping a
	// negative OFFSET.
	page = max(page, 1)

	var messages []*types.Message
	if err := r.db.WithContext(ctx).
		Where("session_id = ?", sessionID).
		// created_at is not unique (messages can share a timestamp), so a
		// secondary key is needed for OFFSET pagination to be stable; without
		// it, tied rows may be repeated or skipped across page boundaries.
		Order("created_at ASC").
		Order("id ASC").
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Find(&messages).Error; err != nil {
		return nil, err
	}
	return messages, nil
}
// GetRecentMessagesBySession retrieves recent messages for a session
func (r *messageRepository) GetRecentMessagesBySession(
ctx context.Context, sessionID string, limit int,
) ([]*types.Message, error) {
var messages []*types.Message
if err := r.db.WithContext(ctx).Where(
"session_id = ?", sessionID,
).Order("created_at DESC").Limit(limit).Find(&messages).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
slices.SortFunc(messages, func(a, b *types.Message) int {
cmp := a.CreatedAt.Compare(b.CreatedAt)
if cmp == 0 {
if a.Role == "user" { // User messages come first
return -1
}
return 1 // Assistant messages come last
}
return cmp
})
return messages, nil
}
// GetMessagesBySessionBeforeTime retrieves messages from a session created before a specific time
func (r *messageRepository) GetMessagesBySessionBeforeTime(
ctx context.Context, sessionID string, beforeTime time.Time, limit int,
) ([]*types.Message, error) {
var messages []*types.Message
if err := r.db.WithContext(ctx).Where(
"session_id = ? AND created_at < ?", sessionID, beforeTime,
).Order("created_at DESC").Limit(limit).Find(&messages).Error; err != nil {
return nil, err
}
slices.SortFunc(messages, func(a, b *types.Message) int {
cmp := a.CreatedAt.Compare(b.CreatedAt)
if cmp == 0 {
if a.Role == "user" { // User messages come first
return -1
}
return 1 // Assistant messages come last
}
return cmp
})
return messages, nil
}
// UpdateMessage writes the fields of message to the stored row identified by
// message.ID within message.SessionID.
//
// NOTE(review): GORM's Updates with a struct argument is documented to skip
// zero-value fields — confirm callers never need to reset a field to its zero
// value through this method.
func (r *messageRepository) UpdateMessage(ctx context.Context, message *types.Message) error {
	scoped := r.db.WithContext(ctx).
		Model(&types.Message{}).
		Where("id = ? AND session_id = ?", message.ID, message.SessionID)
	return scoped.Updates(message).Error
}
// DeleteMessage removes the message with messageID from sessionID and returns
// any database error.
func (r *messageRepository) DeleteMessage(ctx context.Context, sessionID string, messageID string) error {
	cond := r.db.WithContext(ctx).Where("id = ? AND session_id = ?", messageID, sessionID)
	return cond.Delete(&types.Message{}).Error
}
// GetFirstMessageOfUser returns the earliest (by created_at) message with role
// "user" in the given session.
//
// Lookup errors, including GORM's not-found error from First, are returned
// unchanged with a nil message.
func (r *messageRepository) GetFirstMessageOfUser(ctx context.Context, sessionID string) (*types.Message, error) {
	first := new(types.Message)
	err := r.db.WithContext(ctx).
		Where("session_id = ? and role = ?", sessionID, "user").
		Order("created_at ASC").
		First(first).Error
	if err != nil {
		return nil, err
	}
	return first, nil
}
// GetMessageByRequestID looks up the message in sessionID that carries the
// given requestID.
//
// Unlike GetMessage, a missing row is not an error here: it yields (nil, nil).
// Any other database error is returned unchanged.
func (r *messageRepository) GetMessageByRequestID(
	ctx context.Context, sessionID string, requestID string,
) (*types.Message, error) {
	var message types.Message
	err := r.db.WithContext(ctx).
		Where("session_id = ? AND request_id = ?", sessionID, requestID).
		First(&message).Error
	switch {
	case err == nil:
		return &message, nil
	case err == gorm.ErrRecordNotFound:
		return nil, nil
	default:
		return nil, err
	}
}