package utils import ( "crypto/hmac" "crypto/sha1" "fmt" "github.com/google/uuid" "github.com/tal-tech/go-zero/core/logx" "io/ioutil" "job_risk_third/model" "mime" "net/http" "net/url" "sort" "strconv" "strings" "time" ) var ( ErrorBadRequest = fmt.Errorf("无效的请求参数") ) func Validation(r *http.Request) error { if err := ParseHttpParams(r); err != nil { return err } //vars := mux.Vars(r) timestamp, err := strconv.ParseInt(Param(r, "timestamp"), 10, 64) if err != nil { return ErrorBadRequest } now := time.Now().UnixNano() / 1e6 if timestamp > now+100000 || now >= timestamp+6000000 { return fmt.Errorf("请求过期") } sign := Param(r, "sign") if len(sign) == 0 { return ErrorBadRequest } nonce := Param(r, "nonce") if len(nonce) == 0 { return ErrorBadRequest } Params := r.Form var keys []string for k := range Params { keys = append(keys, k) } sort.Strings(keys) str := "" for _, v := range keys { if v != "sign" && v != "nonce" && v != "__json_param__" { udata := url.Values{} udata.Set(v, Params.Get(v)) strUrl := udata.Encode() str += "&" + strUrl } if v == "__json_param__" { udata := url.Values{} udata.Set(Params.Get(v), "") str += "&" + udata.Encode() } } if len(str) > 1 && str[len(str)-1:] == "=" { str = r.Method + str[1:len(str)-1] } logx.Infof("str : %s", str) our := Sha1(str, nonce) if sign != strings.Trim(our, " ") { logx.Infof("our:[%s] sign: [%s] ", our, sign) return fmt.Errorf("错误的签名") } logx.Info(r.RemoteAddr[:9]) if r.RemoteAddr[:9] != "127.0.0.1" && r.RemoteAddr[:9] != "localhost" { //判断是否重放请求 hExists := model.Redis.Exists(nonce).Val() if hExists > 0 { return fmt.Errorf("请勿重复请求") } //防止业务数据重复提交 hExists = model.Redis.Exists(sign).Val() if hExists > 0 { return fmt.Errorf("请勿重复提交数据") } model.Redis.Set(nonce, timestamp, time.Minute*5) model.Redis.Set(sign, timestamp, time.Minute*5) } return nil } func ParseHttpParams(r *http.Request) (err error) { ct := r.Header.Get("Content-Type") if ct == "" { ct = "application/octet-stream" } ct, _, err = mime.ParseMediaType(ct) switch { case ct == "application/json": result, err := ioutil.ReadAll(r.Body) if err != nil { logx.Errorf("ParseHttpParams error info : %v", err) return err } if r.Form == nil { r.Form = url.Values{} } r.Form.Set("__json_param__", string(result)) return nil } return err } //sha1加签 func Sha1(query string, priKey string) string { key := []byte(priKey) mac := hmac.New(sha1.New, key) mac.Write([]byte(query)) //query = base64.StdEncoding.EncodeToString(mac.Sum(nil)) //query = url.QueryEscape(query) query = fmt.Sprintf("%x", mac.Sum(nil)) return query } func Param(r *http.Request, key string) string { var value string value = r.FormValue(key) if len(value) > 0 { return value } value = r.URL.Query().Get(key) if len(value) > 0 { return value } value = r.Header.Get(key) if len(value) > 0 { return value } if cookie, _ := r.Cookie(key); cookie != nil { return cookie.Value } return value } func NewId() string { id := uuid.New().String() id = strings.Replace(id, "-", "", -1) return id }