package repository import ( "context" "slices" "time" "gorm.io/gorm" "knowlege-lsxd/internal/types" "knowlege-lsxd/internal/types/interfaces" ) // messageRepository implements the message repository interface type messageRepository struct { db *gorm.DB } // NewMessageRepository creates a new message repository func NewMessageRepository(db *gorm.DB) interfaces.MessageRepository { return &messageRepository{ db: db, } } // CreateMessage creates a new message func (r *messageRepository) CreateMessage( ctx context.Context, message *types.Message, ) (*types.Message, error) { if err := r.db.WithContext(ctx).Create(message).Error; err != nil { return nil, err } return message, nil } // GetMessage retrieves a message func (r *messageRepository) GetMessage( ctx context.Context, sessionID string, messageID string, ) (*types.Message, error) { var message types.Message if err := r.db.WithContext(ctx).Where( "id = ? AND session_id = ?", messageID, sessionID, ).First(&message).Error; err != nil { return nil, err } return &message, nil } // GetMessagesBySession retrieves all messages for a session with pagination func (r *messageRepository) GetMessagesBySession( ctx context.Context, sessionID string, page int, pageSize int, ) ([]*types.Message, error) { var messages []*types.Message if err := r.db.WithContext(ctx).Where("session_id = ?", sessionID).Order("created_at ASC"). Offset((page - 1) * pageSize).Limit(pageSize).Find(&messages).Error; err != nil { return nil, err } return messages, nil } // GetRecentMessagesBySession retrieves recent messages for a session func (r *messageRepository) GetRecentMessagesBySession( ctx context.Context, sessionID string, limit int, ) ([]*types.Message, error) { var messages []*types.Message if err := r.db.WithContext(ctx).Where( "session_id = ?", sessionID, ).Order("created_at DESC").Limit(limit).Find(&messages).Error; err != nil { if err == gorm.ErrRecordNotFound { return nil, nil } return nil, err } slices.SortFunc(messages, func(a, b *types.Message) int { cmp := a.CreatedAt.Compare(b.CreatedAt) if cmp == 0 { if a.Role == "user" { // User messages come first return -1 } return 1 // Assistant messages come last } return cmp }) return messages, nil } // GetMessagesBySessionBeforeTime retrieves messages from a session created before a specific time func (r *messageRepository) GetMessagesBySessionBeforeTime( ctx context.Context, sessionID string, beforeTime time.Time, limit int, ) ([]*types.Message, error) { var messages []*types.Message if err := r.db.WithContext(ctx).Where( "session_id = ? AND created_at < ?", sessionID, beforeTime, ).Order("created_at DESC").Limit(limit).Find(&messages).Error; err != nil { return nil, err } slices.SortFunc(messages, func(a, b *types.Message) int { cmp := a.CreatedAt.Compare(b.CreatedAt) if cmp == 0 { if a.Role == "user" { // User messages come first return -1 } return 1 // Assistant messages come last } return cmp }) return messages, nil } // UpdateMessage updates an existing message func (r *messageRepository) UpdateMessage(ctx context.Context, message *types.Message) error { return r.db.WithContext(ctx).Model(&types.Message{}).Where( "id = ? AND session_id = ?", message.ID, message.SessionID, ).Updates(message).Error } // DeleteMessage deletes a message func (r *messageRepository) DeleteMessage(ctx context.Context, sessionID string, messageID string) error { return r.db.WithContext(ctx).Where( "id = ? AND session_id = ?", messageID, sessionID, ).Delete(&types.Message{}).Error } // GetFirstMessageOfUser retrieves the first message from a user in a session func (r *messageRepository) GetFirstMessageOfUser(ctx context.Context, sessionID string) (*types.Message, error) { var message types.Message if err := r.db.WithContext(ctx).Where( "session_id = ? and role = ?", sessionID, "user", ).Order("created_at ASC").First(&message).Error; err != nil { return nil, err } return &message, nil } // GetMessageByRequestID retrieves a message by request ID func (r *messageRepository) GetMessageByRequestID( ctx context.Context, sessionID string, requestID string, ) (*types.Message, error) { var message types.Message result := r.db.WithContext(ctx). Where("session_id = ? AND request_id = ?", sessionID, requestID). First(&message) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, nil } return nil, result.Error } return &message, nil }