You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

207 lines
5.8KB

  1. /*
  2. * Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved.
  3. *
  4. * This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit).
  5. *
  6. * Use of this source code is governed by MIT license that can be found in the
  7. * LICENSE file in the root of the source tree. All contributing project authors
  8. * may be found in the AUTHORS file in the root of the source tree.
  9. */
  10. #ifndef CRYPTO_SSLBOX_H_
  11. #define CRYPTO_SSLBOX_H_
  12. #include <mutex>
  13. #include <string>
  14. #include <functional>
  15. #include "logger.h"
  16. #include "List.h"
  17. #include "util.h"
  18. #include "Network/Buffer.h"
  19. #include "ResourcePool.h"
  20. typedef struct x509_st X509;
  21. typedef struct evp_pkey_st EVP_PKEY;
  22. typedef struct ssl_ctx_st SSL_CTX;
  23. typedef struct ssl_st SSL;
  24. typedef struct bio_st BIO;
  25. namespace toolkit {
  26. class SSL_Initor {
  27. public:
  28. friend class SSL_Box;
  29. static SSL_Initor &Instance();
  30. /**
  31. * 从文件或字符串中加载公钥和私钥
  32. * 该证书文件必须同时包含公钥和私钥(cer格式的证书只包括公钥,请使用后面的方法加载)
  33. * 客户端默认可以不加载证书(除非服务器要求客户端提供证书)
  34. * @param pem_or_p12 pem或p12文件路径或者文件内容字符串
  35. * @param server_mode 是否为服务器模式
  36. * @param password 私钥加密密码
  37. * @param is_file 参数pem_or_p12是否为文件路径
  38. * @param is_default 是否为默认证书
  39. */
  40. bool loadCertificate(const std::string &pem_or_p12, bool server_mode = true, const std::string &password = "",
  41. bool is_file = true, bool is_default = true);
  42. /**
  43. * 是否忽略无效的证书
  44. * 默认忽略,强烈建议不要忽略!
  45. * @param ignore 标记
  46. */
  47. void ignoreInvalidCertificate(bool ignore = true);
  48. /**
  49. * 信任某证书,一般用于客户端信任自签名的证书或自签名CA签署的证书使用
  50. * 比如说我的客户端要信任我自己签发的证书,那么我们可以只信任这个证书
  51. * @param pem_p12_cer pem文件或p12文件或cer文件路径或内容
  52. * @param server_mode 是否为服务器模式
  53. * @param password pem或p12证书的密码
  54. * @param is_file 是否为文件路径
  55. * @return 是否加载成功
  56. */
  57. bool trustCertificate(const std::string &pem_p12_cer, bool server_mode = false, const std::string &password = "",
  58. bool is_file = true);
  59. /**
  60. * 信任某证书
  61. * @param cer 证书公钥
  62. * @param server_mode 是否为服务模式
  63. * @return 是否加载成功
  64. */
  65. bool trustCertificate(X509 *cer, bool server_mode = false);
  66. private:
  67. SSL_Initor();
  68. ~SSL_Initor();
  69. /**
  70. * 创建SSL对象
  71. */
  72. std::shared_ptr<SSL> makeSSL(bool server_mode);
  73. /**
  74. * 设置ssl context
  75. * @param vhost 虚拟主机名
  76. * @param ctx ssl context
  77. * @param server_mode ssl context
  78. * @param is_default 是否为默认证书
  79. */
  80. bool setContext(const std::string &vhost, const std::shared_ptr<SSL_CTX> &ctx, bool server_mode, bool is_default = true);
  81. /**
  82. * 设置SSL_CTX的默认配置
  83. * @param ctx 对象指针
  84. */
  85. void setupCtx(SSL_CTX *ctx);
  86. /**
  87. * 根据虚拟主机获取SSL_CTX对象
  88. * @param vhost 虚拟主机名
  89. * @param server_mode 是否为服务器模式
  90. * @return SSL_CTX对象
  91. */
  92. std::shared_ptr<SSL_CTX> getSSLCtx(const std::string &vhost, bool server_mode);
  93. std::shared_ptr<SSL_CTX> getSSLCtx_l(const std::string &vhost, bool server_mode);
  94. std::shared_ptr<SSL_CTX> getSSLCtxWildcards(const std::string &vhost, bool server_mode);
  95. /**
  96. * 获取默认的虚拟主机
  97. */
  98. std::string defaultVhost(bool server_mode);
  99. /**
  100. * 完成vhost name 匹配的回调函数
  101. */
  102. static int findCertificate(SSL *ssl, int *ad, void *arg);
  103. private:
  104. struct less_nocase {
  105. bool operator()(const std::string &x, const std::string &y) const {
  106. return strcasecmp(x.data(), y.data()) < 0;
  107. }
  108. };
  109. private:
  110. std::string _default_vhost[2];
  111. std::shared_ptr<SSL_CTX> _ctx_empty[2];
  112. std::map<std::string, std::shared_ptr<SSL_CTX>, less_nocase> _ctxs[2];
  113. std::map<std::string, std::shared_ptr<SSL_CTX>, less_nocase> _ctxs_wildcards[2];
  114. };
  115. ////////////////////////////////////////////////////////////////////////////////////
  116. class SSL_Box {
  117. public:
  118. SSL_Box(bool server_mode = true, bool enable = true, int buff_size = 32 * 1024);
  119. ~SSL_Box();
  120. /**
  121. * 收到密文后,调用此函数解密
  122. * @param buffer 收到的密文数据
  123. */
  124. void onRecv(const Buffer::Ptr &buffer);
  125. /**
  126. * 需要加密明文调用此函数
  127. * @param buffer 需要加密的明文数据
  128. */
  129. void onSend(Buffer::Ptr buffer);
  130. /**
  131. * 设置解密后获取明文的回调
  132. * @param cb 回调对象
  133. */
  134. void setOnDecData(const std::function<void(const Buffer::Ptr &)> &cb);
  135. /**
  136. * 设置加密后获取密文的回调
  137. * @param cb 回调对象
  138. */
  139. void setOnEncData(const std::function<void(const Buffer::Ptr &)> &cb);
  140. /**
  141. * 终结ssl
  142. */
  143. void shutdown();
  144. /**
  145. * 清空数据
  146. */
  147. void flush();
  148. /**
  149. * 设置虚拟主机名
  150. * @param host 虚拟主机名
  151. * @return 是否成功
  152. */
  153. bool setHost(const char *host);
  154. private:
  155. void flushWriteBio();
  156. void flushReadBio();
  157. private:
  158. bool _server_mode;
  159. bool _send_handshake;
  160. bool _is_flush = false;
  161. int _buff_size;
  162. BIO *_read_bio;
  163. BIO *_write_bio;
  164. std::shared_ptr<SSL> _ssl;
  165. List <Buffer::Ptr> _buffer_send;
  166. ResourcePool <BufferRaw> _buffer_pool;
  167. std::function<void(const Buffer::Ptr &)> _on_dec;
  168. std::function<void(const Buffer::Ptr &)> _on_enc;
  169. };
  170. } /* namespace toolkit */
  171. #endif /* CRYPTO_SSLBOX_H_ */