diff --git a/const.go b/const.go index 5435bf9..17e7fe9 100644 --- a/const.go +++ b/const.go @@ -35,3 +35,17 @@ var smsBusinessWithRequestPath = map[SmsBusiness]RequestPath{ SmsBusinessHs: sendSmsHs, SmsBusinessDefault: sendSms, } + +type ( + SmsOption func(*SmsOptionData) + SmsOptionData struct { + Business SmsBusiness + } +) + +func WithBusiness(business SmsBusiness) SmsOption { + + return func(OptionData *SmsOptionData) { + OptionData.Business = business + } +} diff --git a/msg.go b/msg.go index 9f7c08b..be62a1e 100644 --- a/msg.go +++ b/msg.go @@ -58,19 +58,27 @@ func (m *MessageCenter) OAGetDetail(outTradeNo string) (data OAGetDetailData, er // SendSms 短信 // business SmsBusiness -func (m *MessageCenter) SendSms(tels []string, jsonParam string, business SmsBusiness) (data SmsSend, err error) { +func (m *MessageCenter) SendSms(tels []string, jsonParam string, args ...SmsOption) (data SmsSend, err error) { var ( - path RequestPath - ex bool + e = new(SmsOptionData) ) if len(tels) == 0 { err = errors.New("手机号不能为空") return } - if path, ex = smsBusinessWithRequestPath[business]; !ex { - err = errors.New("未知的供应商") - return + for _, arg := range args { + arg(e) } + if e.Business != "" { + if _, ex := smsBusinessWithRequestPath[e.Business]; !ex { + err = errors.New("business参数错误") + return + } + } else { + e.Business = SmsBusinessDefault + } + path := smsBusinessWithRequestPath[e.Business] + param := m.parseSmsSendParam(tels, jsonParam) err = m.post(path, param, &data) if err != nil {