You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

150 lines
3.2KB

  1. package utils
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha1"
  5. "fmt"
  6. "github.com/google/uuid"
  7. "github.com/tal-tech/go-zero/core/logx"
  8. "io/ioutil"
  9. "job_risk_third/model"
  10. "mime"
  11. "net/http"
  12. "net/url"
  13. "sort"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  18. var (
  19. ErrorBadRequest = fmt.Errorf("无效的请求参数")
  20. )
  21. func Validation(r *http.Request) error {
  22. if err := ParseHttpParams(r); err != nil {
  23. return err
  24. }
  25. //vars := mux.Vars(r)
  26. timestamp, err := strconv.ParseInt(Param(r, "timestamp"), 10, 64)
  27. if err != nil {
  28. return ErrorBadRequest
  29. }
  30. now := time.Now().UnixNano() / 1e6
  31. if timestamp > now+100000 || now >= timestamp+6000000 {
  32. return fmt.Errorf("请求过期")
  33. }
  34. sign := Param(r, "sign")
  35. if len(sign) == 0 {
  36. return ErrorBadRequest
  37. }
  38. nonce := Param(r, "nonce")
  39. if len(nonce) == 0 {
  40. return ErrorBadRequest
  41. }
  42. Params := r.Form
  43. var keys []string
  44. for k := range Params {
  45. keys = append(keys, k)
  46. }
  47. sort.Strings(keys)
  48. str := ""
  49. for _, v := range keys {
  50. if v != "sign" && v != "nonce" && v != "__json_param__" {
  51. udata := url.Values{}
  52. udata.Set(v, Params.Get(v))
  53. strUrl := udata.Encode()
  54. str += "&" + strUrl
  55. }
  56. if v == "__json_param__" {
  57. udata := url.Values{}
  58. udata.Set(Params.Get(v), "")
  59. str += "&" + udata.Encode()
  60. }
  61. }
  62. if len(str) > 1 && str[len(str)-1:] == "=" {
  63. str = r.Method + str[1:len(str)-1]
  64. }
  65. logx.Infof("str : %s", str)
  66. our := Sha1(str, nonce)
  67. if sign != strings.Trim(our, " ") {
  68. logx.Infof("our:[%s] sign: [%s] ", our, sign)
  69. return fmt.Errorf("错误的签名")
  70. }
  71. logx.Info(r.RemoteAddr[:9])
  72. if r.RemoteAddr[:9] != "127.0.0.1" && r.RemoteAddr[:9] != "localhost" {
  73. //判断是否重放请求
  74. hExists := model.Redis.Exists(nonce).Val()
  75. if hExists > 0 {
  76. return fmt.Errorf("请勿重复请求")
  77. }
  78. //防止业务数据重复提交
  79. hExists = model.Redis.Exists(sign).Val()
  80. if hExists > 0 {
  81. return fmt.Errorf("请勿重复提交数据")
  82. }
  83. model.Redis.Set(nonce, timestamp, time.Minute*5)
  84. model.Redis.Set(sign, timestamp, time.Minute*5)
  85. }
  86. return nil
  87. }
  88. func ParseHttpParams(r *http.Request) (err error) {
  89. ct := r.Header.Get("Content-Type")
  90. if ct == "" {
  91. ct = "application/octet-stream"
  92. }
  93. ct, _, err = mime.ParseMediaType(ct)
  94. switch {
  95. case ct == "application/json":
  96. result, err := ioutil.ReadAll(r.Body)
  97. if err != nil {
  98. logx.Errorf("ParseHttpParams error info : %v", err)
  99. return err
  100. }
  101. if r.Form == nil {
  102. r.Form = url.Values{}
  103. }
  104. r.Form.Set("__json_param__", string(result))
  105. return nil
  106. }
  107. return err
  108. }
  109. //sha1加签
  110. func Sha1(query string, priKey string) string {
  111. key := []byte(priKey)
  112. mac := hmac.New(sha1.New, key)
  113. mac.Write([]byte(query))
  114. //query = base64.StdEncoding.EncodeToString(mac.Sum(nil))
  115. //query = url.QueryEscape(query)
  116. query = fmt.Sprintf("%x", mac.Sum(nil))
  117. return query
  118. }
  119. func Param(r *http.Request, key string) string {
  120. var value string
  121. value = r.FormValue(key)
  122. if len(value) > 0 {
  123. return value
  124. }
  125. value = r.URL.Query().Get(key)
  126. if len(value) > 0 {
  127. return value
  128. }
  129. value = r.Header.Get(key)
  130. if len(value) > 0 {
  131. return value
  132. }
  133. if cookie, _ := r.Cookie(key); cookie != nil {
  134. return cookie.Value
  135. }
  136. return value
  137. }
  138. func NewId() string {
  139. id := uuid.New().String()
  140. id = strings.Replace(id, "-", "", -1)
  141. return id
  142. }