server.hpp 51 KB


  1. // Copyright Takatoshi Kondo 2017
  2. //
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // (See accompanying file LICENSE_1_0.txt or copy at
  5. // http://www.boost.org/LICENSE_1_0.txt)
  6. #if !defined(MQTT_SERVER_HPP)
  7. #define MQTT_SERVER_HPP
  8. #include <mqtt/variant.hpp> // should be top to configure variant limit
  9. #include <memory>
  10. #include <boost/asio.hpp>
  11. #include <mqtt/namespace.hpp>
  12. #include <mqtt/tcp_endpoint.hpp>
  13. #include <mqtt/endpoint.hpp>
  14. #include <mqtt/move.hpp>
  15. #include <mqtt/callable_overlay.hpp>
  16. #include <mqtt/strand.hpp>
  17. #include <mqtt/null_strand.hpp>
  18. namespace MQTT_NS {
  19. namespace as = boost::asio;
  20. template <typename Mutex, template<typename...> class LockGuard, std::size_t PacketIdBytes>
  21. class server_endpoint : public endpoint<Mutex, LockGuard, PacketIdBytes> {
  22. public:
  23. using endpoint<Mutex, LockGuard, PacketIdBytes>::endpoint;
  24. protected:
  25. void on_pre_send() noexcept override {}
  26. void on_close() noexcept override {}
  27. void on_error(error_code /*ec*/) noexcept override {}
  28. protected:
  29. ~server_endpoint() = default;
  30. };
  31. template <
  32. typename Strand = strand,
  33. typename Mutex = std::mutex,
  34. template<typename...> class LockGuard = std::lock_guard,
  35. std::size_t PacketIdBytes = 2
  36. >
  37. class server {
  38. public:
  39. using socket_t = tcp_endpoint<as::ip::tcp::socket, Strand>;
  40. using endpoint_t = callable_overlay<server_endpoint<Mutex, LockGuard, PacketIdBytes>>;
  41. /**
  42. * @brief Accept handler
  43. * After this handler called, the next accept will automatically start.
  44. * @param ep endpoint of the connecting client
  45. */
  46. using accept_handler = std::function<void(std::shared_ptr<endpoint_t> ep)>;
  47. /**
  48. * @brief Error handler during after accepted before connection established
  49. * After this handler called, the next accept will automatically start.
  50. * @param ec error code
  51. * @param ioc_con io_context for incoming connection
  52. */
  53. using connection_error_handler = std::function<void(error_code ec, as::io_context& ioc_con)>;
  54. /**
  55. * @brief Error handler for listen and accpet
  56. * After this handler called, the next accept won't start
  57. * You need to call listen() again if you want to restart accepting.
  58. * @param ec error code
  59. */
  60. using error_handler = std::function<void(error_code ec)>;
  61. /**
  62. * @brief Error handler for listen and accpet
  63. * After this handler called, the next accept won't start
  64. * You need to call listen() again if you want to restart accepting.
  65. * @param ec error code
  66. * @param ioc_con io_context for listen or accept
  67. */
  68. using error_handler_with_ioc = std::function<void(error_code ec, as::io_context& ioc_accept)>;
  69. template <typename AsioEndpoint, typename AcceptorConfig>
  70. server(
  71. AsioEndpoint&& ep,
  72. as::io_context& ioc_accept,
  73. as::io_context& ioc_con,
  74. AcceptorConfig&& config)
  75. : ep_(std::forward<AsioEndpoint>(ep)),
  76. ioc_accept_(ioc_accept),
  77. ioc_con_(&ioc_con),
  78. ioc_con_getter_([this]() -> as::io_context& { return *ioc_con_; }),
  79. acceptor_(as::ip::tcp::acceptor(ioc_accept_, ep_)),
  80. config_(std::forward<AcceptorConfig>(config)) {
  81. config_(acceptor_.value());
  82. }
  83. template <typename AsioEndpoint>
  84. server(
  85. AsioEndpoint&& ep,
  86. as::io_context& ioc_accept,
  87. as::io_context& ioc_con)
  88. : server(std::forward<AsioEndpoint>(ep), ioc_accept, ioc_con, [](as::ip::tcp::acceptor&) {}) {}
  89. template <typename AsioEndpoint, typename AcceptorConfig>
  90. server(
  91. AsioEndpoint&& ep,
  92. as::io_context& ioc,
  93. AcceptorConfig&& config)
  94. : server(std::forward<AsioEndpoint>(ep), ioc, ioc, std::forward<AcceptorConfig>(config)) {}
  95. template <typename AsioEndpoint>
  96. server(
  97. AsioEndpoint&& ep,
  98. as::io_context& ioc)
  99. : server(std::forward<AsioEndpoint>(ep), ioc, ioc, [](as::ip::tcp::acceptor&) {}) {}
  100. template <typename AsioEndpoint, typename AcceptorConfig>
  101. server(
  102. AsioEndpoint&& ep,
  103. as::io_context& ioc_accept,
  104. std::function<as::io_context&()> ioc_con_getter,
  105. AcceptorConfig&& config = [](as::ip::tcp::acceptor&) {})
  106. : ep_(std::forward<AsioEndpoint>(ep)),
  107. ioc_accept_(ioc_accept),
  108. ioc_con_getter_(force_move(ioc_con_getter)),
  109. acceptor_(as::ip::tcp::acceptor(ioc_accept_, ep_)),
  110. config_(std::forward<AcceptorConfig>(config)) {
  111. config_(acceptor_.value());
  112. }
  113. void listen() {
  114. close_request_ = false;
  115. if (!acceptor_) {
  116. try {
  117. acceptor_.emplace(ioc_accept_, ep_);
  118. config_(acceptor_.value());
  119. }
  120. catch (boost::system::system_error const& e) {
  121. as::post(
  122. ioc_accept_,
  123. [this, ec = e.code()] {
  124. if (h_error_) h_error_(ec, ioc_accept_);
  125. }
  126. );
  127. return;
  128. }
  129. }
  130. do_accept();
  131. }
  132. unsigned short port() const { return acceptor_.value().local_endpoint().port(); }
  133. void close() {
  134. close_request_ = true;
  135. as::post(
  136. ioc_accept_,
  137. [this] {
  138. acceptor_.reset();
  139. }
  140. );
  141. }
  142. void set_accept_handler(accept_handler h = accept_handler()) {
  143. h_accept_ = force_move(h);
  144. }
  145. /**
  146. * @brief Set error handler for listen and accept
  147. * @param h handler
  148. */
  149. void set_error_handler(error_handler h) {
  150. h_error_ =
  151. [h = force_move(h)]
  152. (error_code ec, as::io_context&) {
  153. if (h) h(ec);
  154. };
  155. }
  156. /**
  157. * @brief Set error handler for listen and accept
  158. * @param h handler
  159. */
  160. void set_error_handler(error_handler_with_ioc h = error_handler_with_ioc()) {
  161. h_error_ = force_move(h);
  162. }
  163. /**
  164. * @brief Set error handler
  165. * @param h handler
  166. */
  167. void set_connection_error_handler(connection_error_handler h = connection_error_handler()) {
  168. h_connection_error_ = force_move(h);
  169. }
  170. /**
  171. * @brief Set MQTT protocol version
  172. * @param version accepting protocol version
  173. * If the specific version is set, only set version is accepted.
  174. * If the version is set to protocol_version::undetermined, all versions are accepted.
  175. * Initial value is protocol_version::undetermined.
  176. */
  177. void set_protocol_version(protocol_version version) {
  178. version_ = version;
  179. }
  180. private:
  181. void do_accept() {
  182. if (close_request_) return;
  183. auto& ioc_con = ioc_con_getter_();
  184. auto socket = std::make_shared<socket_t>(ioc_con);
  185. acceptor_.value().async_accept(
  186. socket->lowest_layer(),
  187. [this, socket, &ioc_con]
  188. (error_code ec) mutable {
  189. if (ec) {
  190. acceptor_.reset();
  191. if (h_error_) h_error_(ec, ioc_con);
  192. return;
  193. }
  194. auto sp = std::make_shared<endpoint_t>(ioc_con, force_move(socket), version_);
  195. if (h_accept_) h_accept_(force_move(sp));
  196. do_accept();
  197. }
  198. );
  199. }
  200. private:
  201. as::ip::tcp::endpoint ep_;
  202. as::io_context& ioc_accept_;
  203. as::io_context* ioc_con_ = nullptr;
  204. std::function<as::io_context&()> ioc_con_getter_;
  205. optional<as::ip::tcp::acceptor> acceptor_;
  206. std::function<void(as::ip::tcp::acceptor&)> config_;
  207. bool close_request_{false};
  208. accept_handler h_accept_;
  209. connection_error_handler h_connection_error_;
  210. error_handler_with_ioc h_error_;
  211. protocol_version version_ = protocol_version::undetermined;
  212. };
  213. #if defined(MQTT_USE_TLS)
  214. template <
  215. typename Strand = strand,
  216. typename Mutex = std::mutex,
  217. template<typename...> class LockGuard = std::lock_guard,
  218. std::size_t PacketIdBytes = 2
  219. >
  220. class server_tls {
  221. public:
  222. using socket_t = tcp_endpoint<tls::stream<as::ip::tcp::socket>, Strand>;
  223. using endpoint_t = callable_overlay<server_endpoint<Mutex, LockGuard, PacketIdBytes>>;
  224. /**
  225. * @brief Accept handler
  226. * After this handler called, the next accept will automatically start.
  227. * @param ep endpoint of the connecting client
  228. */
  229. using accept_handler = std::function<void(std::shared_ptr<endpoint_t> ep)>;
  230. /**
  231. * @brief Error handler during after accepted before connection established
  232. * After this handler called, the next accept will automatically start.
  233. * @param ec error code
  234. * @param ioc_con io_context for incoming connection
  235. */
  236. using connection_error_handler = std::function<void(error_code ec, as::io_context& ioc_con)>;
  237. /**
  238. * @brief Error handler for listen and accpet
  239. * After this handler called, the next accept won't start
  240. * You need to call listen() again if you want to restart accepting.
  241. * @param ec error code
  242. */
  243. using error_handler = std::function<void(error_code ec)>;
  244. /**
  245. * @brief Error handler for listen and accpet
  246. * After this handler called, the next accept won't start
  247. * You need to call listen() again if you want to restart accepting.
  248. * @param ec error code
  249. * @param ioc_con io_context for listen or accept
  250. */
  251. using error_handler_with_ioc = std::function<void(error_code ec, as::io_context& ioc_accept)>;
  252. template <typename AsioEndpoint, typename AcceptorConfig>
  253. server_tls(
  254. AsioEndpoint&& ep,
  255. tls::context&& ctx,
  256. as::io_context& ioc_accept,
  257. as::io_context& ioc_con,
  258. AcceptorConfig&& config)
  259. : ep_(std::forward<AsioEndpoint>(ep)),
  260. ioc_accept_(ioc_accept),
  261. ioc_con_(&ioc_con),
  262. ioc_con_getter_([this]() -> as::io_context& { return *ioc_con_; }),
  263. acceptor_(as::ip::tcp::acceptor(ioc_accept_, ep_)),
  264. config_(std::forward<AcceptorConfig>(config)),
  265. ctx_(force_move(ctx)) {
  266. config_(acceptor_.value());
  267. }
  268. template <typename AsioEndpoint>
  269. server_tls(
  270. AsioEndpoint&& ep,
  271. tls::context&& ctx,
  272. as::io_context& ioc_accept,
  273. as::io_context& ioc_con)
  274. : server_tls(std::forward<AsioEndpoint>(ep), force_move(ctx), ioc_accept, ioc_con, [](as::ip::tcp::acceptor&) {}) {}
  275. template <typename AsioEndpoint, typename AcceptorConfig>
  276. server_tls(
  277. AsioEndpoint&& ep,
  278. tls::context&& ctx,
  279. as::io_context& ioc,
  280. AcceptorConfig&& config)
  281. : server_tls(std::forward<AsioEndpoint>(ep), force_move(ctx), ioc, ioc, std::forward<AcceptorConfig>(config)) {}
  282. template <typename AsioEndpoint>
  283. server_tls(
  284. AsioEndpoint&& ep,
  285. tls::context&& ctx,
  286. as::io_context& ioc)
  287. : server_tls(std::forward<AsioEndpoint>(ep), force_move(ctx), ioc, ioc, [](as::ip::tcp::acceptor&) {}) {}
  288. template <typename AsioEndpoint, typename AcceptorConfig>
  289. server_tls(
  290. AsioEndpoint&& ep,
  291. tls::context&& ctx,
  292. as::io_context& ioc_accept,
  293. std::function<as::io_context&()> ioc_con_getter,
  294. AcceptorConfig&& config = [](as::ip::tcp::acceptor&) {})
  295. : ep_(std::forward<AsioEndpoint>(ep)),
  296. ioc_accept_(ioc_accept),
  297. ioc_con_getter_(force_move(ioc_con_getter)),
  298. acceptor_(as::ip::tcp::acceptor(ioc_accept_, ep_)),
  299. config_(std::forward<AcceptorConfig>(config)),
  300. ctx_(force_move(ctx)) {
  301. config_(acceptor_.value());
  302. }
  303. void listen() {
  304. close_request_ = false;
  305. if (!acceptor_) {
  306. try {
  307. acceptor_.emplace(ioc_accept_, ep_);
  308. config_(acceptor_.value());
  309. }
  310. catch (boost::system::system_error const& e) {
  311. as::post(
  312. ioc_accept_,
  313. [this, ec = e.code()] {
  314. if (h_error_) h_error_(ec, ioc_accept_);
  315. }
  316. );
  317. return;
  318. }
  319. }
  320. do_accept();
  321. }
  322. unsigned short port() const { return acceptor_.value().local_endpoint().port(); }
  323. void close() {
  324. close_request_ = true;
  325. as::post(
  326. ioc_accept_,
  327. [this] {
  328. acceptor_.reset();
  329. }
  330. );
  331. }
  332. void set_accept_handler(accept_handler h = accept_handler()) {
  333. h_accept_ = force_move(h);
  334. }
  335. /**
  336. * @brief Set error handler for listen and accept
  337. * @param h handler
  338. */
  339. void set_error_handler(error_handler h) {
  340. h_error_ =
  341. [h = force_move(h)]
  342. (error_code ec, as::io_context&) {
  343. if (h) h(ec);
  344. };
  345. }
  346. /**
  347. * @brief Set error handler for listen and accept
  348. * @param h handler
  349. */
  350. void set_error_handler(error_handler_with_ioc h = error_handler_with_ioc()) {
  351. h_error_ = force_move(h);
  352. }
  353. /**
  354. * @brief Set error handler
  355. * @param h handler
  356. */
  357. void set_connection_error_handler(connection_error_handler h = connection_error_handler()) {
  358. h_connection_error_ = force_move(h);
  359. }
  360. /**
  361. * @brief Set MQTT protocol version
  362. * @param version accepting protocol version
  363. * If the specific version is set, only set version is accepted.
  364. * If the version is set to protocol_version::undetermined, all versions are accepted.
  365. * Initial value is protocol_version::undetermined.
  366. */
  367. void set_protocol_version(protocol_version version) {
  368. version_ = version;
  369. }
  370. /**
  371. * @bried Set underlying layer connection timeout.
  372. * The timer is set after TCP layer connection accepted.
  373. * The timer is cancelled just before accept handler is called.
  374. * If the timer is fired, the endpoint is removed, the socket is automatically closed.
  375. * The default timeout value is 10 seconds.
  376. * @param timeout timeout value
  377. */
  378. void set_underlying_connect_timeout(std::chrono::steady_clock::duration timeout) {
  379. underlying_connect_timeout_ = force_move(timeout);
  380. }
  381. /**
  382. * @brief Get boost asio ssl context.
  383. * @return ssl context
  384. */
  385. tls::context& get_ssl_context() {
  386. return ctx_;
  387. }
  388. /**
  389. * @brief Get boost asio ssl context.
  390. * @return ssl context
  391. */
  392. tls::context const& get_ssl_context() const {
  393. return ctx_;
  394. }
  395. using verify_cb_t = std::function<bool (bool, boost::asio::ssl::verify_context&, std::shared_ptr<optional<std::string>> const&) >;
  396. void set_verify_callback(verify_cb_t verify_cb) {
  397. verify_cb_with_username_ = verify_cb;
  398. }
  399. private:
  400. void do_accept() {
  401. if (close_request_) return;
  402. auto& ioc_con = ioc_con_getter_();
  403. auto socket = std::make_shared<socket_t>(ioc_con, ctx_);
  404. auto ps = socket.get();
  405. acceptor_.value().async_accept(
  406. ps->lowest_layer(),
  407. [this, socket = force_move(socket), &ioc_con]
  408. (error_code ec) mutable {
  409. if (ec) {
  410. acceptor_.reset();
  411. if (h_error_) h_error_(ec, ioc_con);
  412. return;
  413. }
  414. auto underlying_finished = std::make_shared<bool>(false);
  415. auto connection_error_called = std::make_shared<bool>(false);
  416. auto tim = std::make_shared<as::steady_timer>(ioc_con);
  417. tim->expires_after(underlying_connect_timeout_);
  418. tim->async_wait(
  419. [
  420. this,
  421. socket,
  422. tim,
  423. underlying_finished,
  424. connection_error_called,
  425. &ioc_con
  426. ]
  427. (error_code ec) {
  428. if (*underlying_finished) return;
  429. if (ec) return; // timer cancelled
  430. socket->post(
  431. [this, socket, connection_error_called, &ioc_con] {
  432. boost::system::error_code close_ec;
  433. socket->lowest_layer().close(close_ec);
  434. if (h_connection_error_ && !*connection_error_called) {
  435. h_connection_error_(
  436. boost::system::errc::make_error_code(
  437. boost::system::errc::stream_timeout
  438. ),
  439. ioc_con
  440. );
  441. *connection_error_called = true;
  442. }
  443. }
  444. );
  445. }
  446. );
  447. auto ps = socket.get();
  448. auto username = std::make_shared<optional<std::string>>(); // shared_ptr for username
  449. auto verify_cb_ = [this, username] // copy capture socket shared_ptr
  450. (bool preverified, boost::asio::ssl::verify_context& ctx) {
  451. // user can set username in the callback
  452. return verify_cb_with_username_
  453. ? verify_cb_with_username_(preverified, ctx, username)
  454. : false;
  455. };
  456. ctx_.set_verify_mode(MQTT_NS::tls::verify_peer);
  457. ctx_.set_verify_callback(verify_cb_);
  458. ps->async_handshake(
  459. tls::stream_base::server,
  460. [
  461. this,
  462. socket = force_move(socket),
  463. tim,
  464. underlying_finished,
  465. connection_error_called,
  466. &ioc_con,
  467. username
  468. ]
  469. (error_code ec) mutable {
  470. *underlying_finished = true;
  471. tim->cancel();
  472. if (ec) {
  473. if (h_connection_error_ && !*connection_error_called) {
  474. h_connection_error_(ec, ioc_con);
  475. *connection_error_called = true;
  476. }
  477. return;
  478. }
  479. auto sp = std::make_shared<endpoint_t>(ioc_con, force_move(socket), version_);
  480. sp->set_preauthed_user_name(*username);
  481. if (h_accept_) h_accept_(force_move(sp));
  482. }
  483. );
  484. do_accept();
  485. }
  486. );
  487. }
  488. private:
  489. verify_cb_t verify_cb_with_username_;
  490. as::ip::tcp::endpoint ep_;
  491. as::io_context& ioc_accept_;
  492. as::io_context* ioc_con_ = nullptr;
  493. std::function<as::io_context&()> ioc_con_getter_;
  494. optional<as::ip::tcp::acceptor> acceptor_;
  495. std::function<void(as::ip::tcp::acceptor&)> config_;
  496. bool close_request_{false};
  497. accept_handler h_accept_;
  498. connection_error_handler h_connection_error_;
  499. error_handler_with_ioc h_error_;
  500. tls::context ctx_;
  501. protocol_version version_ = protocol_version::undetermined;
  502. std::chrono::steady_clock::duration underlying_connect_timeout_ = std::chrono::seconds(10);
  503. };
  504. #endif // defined(MQTT_USE_TLS)
  505. #if defined(MQTT_USE_WS)
  506. template <
  507. typename Strand = strand,
  508. typename Mutex = std::mutex,
  509. template<typename...> class LockGuard = std::lock_guard,
  510. std::size_t PacketIdBytes = 2
  511. >
  512. class server_ws {
  513. public:
  514. using socket_t = ws_endpoint<as::ip::tcp::socket, Strand>;
  515. using endpoint_t = callable_overlay<server_endpoint<Mutex, LockGuard, PacketIdBytes>>;
  516. /**
  517. * @brief Accept handler
  518. * @param ep endpoint of the connecting client
  519. */
  520. using accept_handler = std::function<void(std::shared_ptr<endpoint_t> ep)>;
  521. /**
  522. * @brief Error handler during after accepted before connection established
  523. * After this handler called, the next accept will automatically start.
  524. * @param ec error code
  525. * @param ioc_con io_context for incoming connection
  526. */
  527. using connection_error_handler = std::function<void(error_code ec, as::io_context& ioc_con)>;
  528. /**
  529. * @brief Error handler for listen and accpet
  530. * After this handler called, the next accept won't start
  531. * You need to call listen() again if you want to restart accepting.
  532. * @param ec error code
  533. */
  534. using error_handler = std::function<void(error_code ec)>;
  535. /**
  536. * @brief Error handler for listen and accpet
  537. * After this handler called, the next accept won't start
  538. * You need to call listen() again if you want to restart accepting.
  539. * @param ec error code
  540. * @param ioc_con io_context for listen or accept
  541. */
  542. using error_handler_with_ioc = std::function<void(error_code ec, as::io_context& ioc_accept)>;
  543. template <typename AsioEndpoint, typename AcceptorConfig>
  544. server_ws(
  545. AsioEndpoint&& ep,
  546. as::io_context& ioc_accept,
  547. as::io_context& ioc_con,
  548. AcceptorConfig&& config)
  549. : ep_(std::forward<AsioEndpoint>(ep)),
  550. ioc_accept_(ioc_accept),
  551. ioc_con_(&ioc_con),
  552. ioc_con_getter_([this]() -> as::io_context& { return *ioc_con_; }),
  553. acceptor_(as::ip::tcp::acceptor(ioc_accept_, ep_)),
  554. config_(std::forward<AcceptorConfig>(config)) {
  555. config_(acceptor_.value());
  556. }
  557. template <typename AsioEndpoint>
  558. server_ws(
  559. AsioEndpoint&& ep,
  560. as::io_context& ioc_accept,
  561. as::io_context& ioc_con)
  562. : server_ws(std::forward<AsioEndpoint>(ep), ioc_accept, ioc_con, [](as::ip::tcp::acceptor&) {}) {}
  563. template <typename AsioEndpoint, typename AcceptorConfig>
  564. server_ws(
  565. AsioEndpoint&& ep,
  566. as::io_context& ioc,
  567. AcceptorConfig&& config)
  568. : server_ws(std::forward<AsioEndpoint>(ep), ioc, ioc, std::forward<AcceptorConfig>(config)) {}
  569. template <typename AsioEndpoint>
  570. server_ws(
  571. AsioEndpoint&& ep,
  572. as::io_context& ioc)
  573. : server_ws(std::forward<AsioEndpoint>(ep), ioc, ioc, [](as::ip::tcp::acceptor&) {}) {}
  574. template <typename AsioEndpoint, typename AcceptorConfig>
  575. server_ws(
  576. AsioEndpoint&& ep,
  577. as::io_context& ioc_accept,
  578. std::function<as::io_context&()> ioc_con_getter,
  579. AcceptorConfig&& config = [](as::ip::tcp::acceptor&) {})
  580. : ep_(std::forward<AsioEndpoint>(ep)),
  581. ioc_accept_(ioc_accept),
  582. ioc_con_getter_(force_move(ioc_con_getter)),
  583. acceptor_(as::ip::tcp::acceptor(ioc_accept_, ep_)),
  584. config_(std::forward<AcceptorConfig>(config)) {
  585. config_(acceptor_.value());
  586. }
  587. void listen() {
  588. close_request_ = false;
  589. if (!acceptor_) {
  590. try {
  591. acceptor_.emplace(ioc_accept_, ep_);
  592. config_(acceptor_.value());
  593. }
  594. catch (boost::system::system_error const& e) {
  595. as::post(
  596. ioc_accept_,
  597. [this, ec = e.code()] {
  598. if (h_error_) h_error_(ec, ioc_accept_);
  599. }
  600. );
  601. return;
  602. }
  603. }
  604. do_accept();
  605. }
  606. unsigned short port() const { return acceptor_.value().local_endpoint().port(); }
  607. void close() {
  608. close_request_ = true;
  609. as::post(
  610. ioc_accept_,
  611. [this] {
  612. acceptor_.reset();
  613. }
  614. );
  615. }
  616. void set_accept_handler(accept_handler h = accept_handler()) {
  617. h_accept_ = force_move(h);
  618. }
  619. /**
  620. * @brief Set error handler for listen and accept
  621. * @param h handler
  622. */
  623. void set_error_handler(error_handler h) {
  624. h_error_ =
  625. [h = force_move(h)]
  626. (error_code ec, as::io_context&) {
  627. if (h) h(ec);
  628. };
  629. }
  630. /**
  631. * @brief Set error handler for listen and accept
  632. * @param h handler
  633. */
  634. void set_error_handler(error_handler_with_ioc h = error_handler_with_ioc()) {
  635. h_error_ = force_move(h);
  636. }
  637. /**
  638. * @brief Set error handler
  639. * @param h handler
  640. */
  641. void set_connection_error_handler(connection_error_handler h = connection_error_handler()) {
  642. h_connection_error_ = force_move(h);
  643. }
  644. /**
  645. * @brief Set MQTT protocol version
  646. * @param version accepting protocol version
  647. * If the specific version is set, only set version is accepted.
  648. * If the version is set to protocol_version::undetermined, all versions are accepted.
  649. * Initial value is protocol_version::undetermined.
  650. */
  651. void set_protocol_version(protocol_version version) {
  652. version_ = version;
  653. }
  654. /**
  655. * @bried Set underlying layer connection timeout.
  656. * The timer is set after TCP layer connection accepted.
  657. * The timer is cancelled just before accept handler is called.
  658. * If the timer is fired, the endpoint is removed, the socket is automatically closed.
  659. * The default timeout value is 10 seconds.
  660. * @param timeout timeout value
  661. */
  662. void set_underlying_connect_timeout(std::chrono::steady_clock::duration timeout) {
  663. underlying_connect_timeout_ = force_move(timeout);
  664. }
  665. private:
  666. void do_accept() {
  667. if (close_request_) return;
  668. auto& ioc_con = ioc_con_getter_();
  669. auto socket = std::make_shared<socket_t>(ioc_con);
  670. auto ps = socket.get();
  671. acceptor_.value().async_accept(
  672. ps->next_layer(),
  673. [this, socket = force_move(socket), &ioc_con]
  674. (error_code ec) mutable {
  675. if (ec) {
  676. acceptor_.reset();
  677. if (h_error_) h_error_(ec, ioc_con);
  678. return;
  679. }
  680. auto underlying_finished = std::make_shared<bool>(false);
  681. auto connection_error_called = std::make_shared<bool>(false);
  682. auto tim = std::make_shared<as::steady_timer>(ioc_con);
  683. tim->expires_after(underlying_connect_timeout_);
  684. tim->async_wait(
  685. [
  686. this,
  687. socket,
  688. tim,
  689. underlying_finished,
  690. connection_error_called,
  691. &ioc_con
  692. ]
  693. (error_code ec) {
  694. if (*underlying_finished) return;
  695. if (ec) return; // timer cancelled
  696. socket->post(
  697. [this, socket, connection_error_called, &ioc_con] {
  698. boost::system::error_code close_ec;
  699. socket->lowest_layer().close(close_ec);
  700. if (h_connection_error_ && !*connection_error_called) {
  701. h_connection_error_(
  702. boost::system::errc::make_error_code(
  703. boost::system::errc::stream_timeout
  704. ),
  705. ioc_con
  706. );
  707. *connection_error_called = true;
  708. }
  709. }
  710. );
  711. }
  712. );
  713. auto sb = std::make_shared<boost::asio::streambuf>();
  714. auto request = std::make_shared<boost::beast::http::request<boost::beast::http::string_body>>();
  715. auto ps = socket.get();
  716. boost::beast::http::async_read(
  717. ps->next_layer(),
  718. *sb,
  719. *request,
  720. [
  721. this,
  722. socket = force_move(socket),
  723. sb,
  724. request,
  725. tim,
  726. underlying_finished,
  727. connection_error_called,
  728. &ioc_con
  729. ]
  730. (error_code ec, std::size_t) mutable {
  731. if (ec) {
  732. *underlying_finished = true;
  733. tim->cancel();
  734. if (h_connection_error_ && !*connection_error_called) {
  735. h_connection_error_(ec, ioc_con);
  736. *connection_error_called = true;
  737. }
  738. return;
  739. }
  740. if (!boost::beast::websocket::is_upgrade(*request)) {
  741. *underlying_finished = true;
  742. tim->cancel();
  743. if (h_connection_error_ && !*connection_error_called) {
  744. h_connection_error_(
  745. boost::system::errc::make_error_code(
  746. boost::system::errc::protocol_error
  747. ),
  748. ioc_con
  749. );
  750. *connection_error_called = true;
  751. }
  752. return;
  753. }
  754. auto ps = socket.get();
  755. #if BOOST_BEAST_VERSION >= 248
  756. auto it = request->find("Sec-WebSocket-Protocol");
  757. if (it != request->end()) {
  758. ps->set_option(
  759. boost::beast::websocket::stream_base::decorator(
  760. [name = it->name(), value = it->value()] // name is enum, value is boost::string_view
  761. (boost::beast::websocket::response_type& res) {
  762. // This lambda is called before the scope out point *1
  763. res.set(name, value);
  764. }
  765. )
  766. );
  767. }
  768. ps->async_accept(
  769. *request,
  770. [
  771. this,
  772. socket = force_move(socket),
  773. tim,
  774. underlying_finished,
  775. connection_error_called,
  776. &ioc_con
  777. ]
  778. (error_code ec) mutable {
  779. *underlying_finished = true;
  780. tim->cancel();
  781. if (ec) {
  782. if (h_connection_error_ && !*connection_error_called) {
  783. h_connection_error_(ec, ioc_con);
  784. }
  785. *connection_error_called = true;
  786. return;
  787. }
  788. auto sp = std::make_shared<endpoint_t>(ioc_con, force_move(socket), version_);
  789. if (h_accept_) h_accept_(force_move(sp));
  790. }
  791. );
  792. #else // BOOST_BEAST_VERSION >= 248
  793. ps->async_accept_ex(
  794. *request,
  795. [request, connection_error_called]
  796. (boost::beast::websocket::response_type& m) {
  797. auto it = request->find("Sec-WebSocket-Protocol");
  798. if (it != request->end()) {
  799. m.insert(it->name(), it->value());
  800. }
  801. },
  802. [this, socket = force_move(socket), tim, underlying_finished, &ioc_con]
  803. (error_code ec) mutable {
  804. *underlying_finished = true;
  805. tim->cancel();
  806. if (ec) {
  807. if (h_connection_error_ && !*connection_error_called) {
  808. h_connection_error_(ec, ioc_con);
  809. *connection_error_called = true;
  810. }
  811. return;
  812. }
  813. auto sp = std::make_shared<endpoint_t>(ioc_con, force_move(socket), version_);
  814. if (h_accept_) h_accept_(force_move(sp));
  815. }
  816. );
  817. #endif // BOOST_BEAST_VERSION >= 248
  818. // scope out point *1
  819. }
  820. );
  821. do_accept();
  822. }
  823. );
  824. }
  825. private:
  826. as::ip::tcp::endpoint ep_;
  827. as::io_context& ioc_accept_;
  828. as::io_context* ioc_con_ = nullptr;
  829. std::function<as::io_context&()> ioc_con_getter_;
  830. optional<as::ip::tcp::acceptor> acceptor_;
  831. std::function<void(as::ip::tcp::acceptor&)> config_;
  832. bool close_request_{false};
  833. accept_handler h_accept_;
  834. connection_error_handler h_connection_error_;
  835. error_handler_with_ioc h_error_;
  836. protocol_version version_ = protocol_version::undetermined;
  837. std::chrono::steady_clock::duration underlying_connect_timeout_ = std::chrono::seconds(10);
  838. };
  839. #if defined(MQTT_USE_TLS)
  840. template <
  841. typename Strand = strand,
  842. typename Mutex = std::mutex,
  843. template<typename...> class LockGuard = std::lock_guard,
  844. std::size_t PacketIdBytes = 2
  845. >
  846. class server_tls_ws {
  847. public:
  848. using socket_t = ws_endpoint<tls::stream<as::ip::tcp::socket>, Strand>;
  849. using endpoint_t = callable_overlay<server_endpoint<Mutex, LockGuard, PacketIdBytes>>;
  850. /**
  851. * @brief Accept handler
  852. * @param ep endpoint of the connecting client
  853. */
  854. using accept_handler = std::function<void(std::shared_ptr<endpoint_t> ep)>;
  855. /**
  856. * @brief Error handler during after accepted before connection established
  857. * After this handler called, the next accept will automatically start.
  858. * @param ec error code
  859. * @param ioc_con io_context for incoming connection
  860. */
  861. using connection_error_handler = std::function<void(error_code ec, as::io_context& ioc_con)>;
  862. /**
  863. * @brief Error handler for listen and accept
  864. * After this handler called, the next accept won't start
  865. * You need to call listen() again if you want to restart accepting.
  866. * @param ec error code
  867. */
  868. using error_handler = std::function<void(error_code ec)>;
  869. /**
  870. * @brief Error handler for listen and accept
  871. * After this handler called, the next accept won't start
  872. * You need to call listen() again if you want to restart accepting.
  873. * @param ec error code
  874. * @param ioc_accept io_context for listen or accept
  875. */
  876. using error_handler_with_ioc = std::function<void(error_code ec, as::io_context& ioc_accept)>;
  877. template <typename AsioEndpoint, typename AcceptorConfig>
  878. server_tls_ws(
  879. AsioEndpoint&& ep,
  880. tls::context&& ctx,
  881. as::io_context& ioc_accept,
  882. as::io_context& ioc_con,
  883. AcceptorConfig&& config)
  884. : ep_(std::forward<AsioEndpoint>(ep)),
  885. ioc_accept_(ioc_accept),
  886. ioc_con_(&ioc_con),
  887. ioc_con_getter_([this]() -> as::io_context& { return *ioc_con_; }),
  888. acceptor_(as::ip::tcp::acceptor(ioc_accept_, ep_)),
  889. config_(std::forward<AcceptorConfig>(config)),
  890. ctx_(force_move(ctx)) {
  891. config_(acceptor_.value());
  892. }
  893. template <typename AsioEndpoint>
  894. server_tls_ws(
  895. AsioEndpoint&& ep,
  896. tls::context&& ctx,
  897. as::io_context& ioc_accept,
  898. as::io_context& ioc_con)
  899. : server_tls_ws(std::forward<AsioEndpoint>(ep), force_move(ctx), ioc_accept, ioc_con, [](as::ip::tcp::acceptor&) {}) {}
  900. template <typename AsioEndpoint, typename AcceptorConfig>
  901. server_tls_ws(
  902. AsioEndpoint&& ep,
  903. tls::context&& ctx,
  904. as::io_context& ioc,
  905. AcceptorConfig&& config)
  906. : server_tls_ws(std::forward<AsioEndpoint>(ep), force_move(ctx), ioc, ioc, std::forward<AcceptorConfig>(config)) {}
  907. template <typename AsioEndpoint>
  908. server_tls_ws(
  909. AsioEndpoint&& ep,
  910. tls::context&& ctx,
  911. as::io_context& ioc)
  912. : server_tls_ws(std::forward<AsioEndpoint>(ep), force_move(ctx), ioc, ioc, [](as::ip::tcp::acceptor&) {}) {}
  913. template <typename AsioEndpoint, typename AcceptorConfig>
  914. server_tls_ws(
  915. AsioEndpoint&& ep,
  916. tls::context&& ctx,
  917. as::io_context& ioc_accept,
  918. std::function<as::io_context&()> ioc_con_getter,
  919. AcceptorConfig&& config = [](as::ip::tcp::acceptor&) {})
  920. : ep_(std::forward<AsioEndpoint>(ep)),
  921. ioc_accept_(ioc_accept),
  922. ioc_con_getter_(force_move(ioc_con_getter)),
  923. acceptor_(as::ip::tcp::acceptor(ioc_accept_, ep_)),
  924. config_(std::forward<AcceptorConfig>(config)),
  925. ctx_(force_move(ctx)) {
  926. config_(acceptor_.value());
  927. }
  928. void listen() {
  929. close_request_ = false;
  930. if (!acceptor_) {
  931. try {
  932. acceptor_.emplace(ioc_accept_, ep_);
  933. config_(acceptor_.value());
  934. }
  935. catch (boost::system::system_error const& e) {
  936. as::post(
  937. ioc_accept_,
  938. [this, ec = e.code()] {
  939. if (h_error_) h_error_(ec, ioc_accept_);
  940. }
  941. );
  942. return;
  943. }
  944. }
  945. do_accept();
  946. }
  947. unsigned short port() const { return acceptor_.value().local_endpoint().port(); }
  948. void close() {
  949. close_request_ = true;
  950. as::post(
  951. ioc_accept_,
  952. [this] {
  953. acceptor_.reset();
  954. }
  955. );
  956. }
  957. void set_accept_handler(accept_handler h = accept_handler()) {
  958. h_accept_ = force_move(h);
  959. }
  960. /**
  961. * @brief Set error handler for listen and accept
  962. * @param h handler
  963. */
  964. void set_error_handler(error_handler h) {
  965. h_error_ =
  966. [h = force_move(h)]
  967. (error_code ec, as::io_context&) {
  968. if (h) h(ec);
  969. };
  970. }
  971. /**
  972. * @brief Set error handler for listen and accept
  973. * @param h handler
  974. */
  975. void set_error_handler(error_handler_with_ioc h = error_handler_with_ioc()) {
  976. h_error_ = force_move(h);
  977. }
  978. /**
  979. * @brief Set error handler
  980. * @param h handler
  981. */
  982. void set_connection_error_handler(connection_error_handler h = connection_error_handler()) {
  983. h_connection_error_ = force_move(h);
  984. }
  985. /**
  986. * @brief Set MQTT protocol version
  987. * @param version accepting protocol version
  988. * If the specific version is set, only set version is accepted.
  989. * If the version is set to protocol_version::undetermined, all versions are accepted.
  990. * Initial value is protocol_version::undetermined.
  991. */
  992. void set_protocol_version(protocol_version version) {
  993. version_ = version;
  994. }
  995. /**
  996. * @bried Set underlying layer connection timeout.
  997. * The timer is set after TCP layer connection accepted.
  998. * The timer is cancelled just before accept handler is called.
  999. * If the timer is fired, the endpoint is removed, the socket is automatically closed.
  1000. * The default timeout value is 10 seconds.
  1001. * @param timeout timeout value
  1002. */
  1003. void set_underlying_connect_timeout(std::chrono::steady_clock::duration timeout) {
  1004. underlying_connect_timeout_ = force_move(timeout);
  1005. }
  1006. /**
  1007. * @brief Get boost asio ssl context.
  1008. * @return ssl context
  1009. */
  1010. tls::context& get_ssl_context() {
  1011. return ctx_;
  1012. }
  1013. /**
  1014. * @brief Get boost asio ssl context.
  1015. * @return ssl context
  1016. */
  1017. tls::context const& get_ssl_context() const {
  1018. return ctx_;
  1019. }
  1020. using verify_cb_t = std::function<bool (bool, boost::asio::ssl::verify_context&, std::shared_ptr<optional<std::string>> const&) >;
  1021. void set_verify_callback(verify_cb_t verify_cb) {
  1022. verify_cb_with_username_ = verify_cb;
  1023. }
  1024. private:
  1025. void do_accept() {
  1026. if (close_request_) return;
  1027. auto& ioc_con = ioc_con_getter_();
  1028. auto socket = std::make_shared<socket_t>(ioc_con, ctx_);
  1029. auto ps = socket.get();
  1030. acceptor_.value().async_accept(
  1031. ps->next_layer().next_layer(),
  1032. [this, socket = force_move(socket), &ioc_con]
  1033. (error_code ec) mutable {
  1034. if (ec) {
  1035. acceptor_.reset();
  1036. if (h_error_) h_error_(ec, ioc_con);
  1037. return;
  1038. }
  1039. auto underlying_finished = std::make_shared<bool>(false);
  1040. auto connection_error_called = std::make_shared<bool>(false);
  1041. auto tim = std::make_shared<as::steady_timer>(ioc_con);
  1042. tim->expires_after(underlying_connect_timeout_);
  1043. tim->async_wait(
  1044. [
  1045. this,
  1046. socket,
  1047. tim,
  1048. underlying_finished,
  1049. connection_error_called,
  1050. &ioc_con
  1051. ]
  1052. (error_code ec) {
  1053. if (*underlying_finished) return;
  1054. if (ec) return; // timer cancelled
  1055. socket->post(
  1056. [this, socket, connection_error_called, &ioc_con] {
  1057. boost::system::error_code close_ec;
  1058. socket->lowest_layer().close(close_ec);
  1059. if (h_connection_error_ && !*connection_error_called) {
  1060. h_connection_error_(
  1061. boost::system::errc::make_error_code(
  1062. boost::system::errc::stream_timeout
  1063. ),
  1064. ioc_con
  1065. );
  1066. *connection_error_called = true;
  1067. }
  1068. }
  1069. );
  1070. }
  1071. );
  1072. auto ps = socket.get();
  1073. auto username = std::make_shared<optional<std::string>>(); // shared_ptr for username
  1074. auto verify_cb_ = [this, username] // copy capture socket shared_ptr
  1075. (bool preverified, boost::asio::ssl::verify_context& ctx) {
  1076. // user can set username in the callback
  1077. return verify_cb_with_username_
  1078. ? verify_cb_with_username_(preverified, ctx, username)
  1079. : false;
  1080. };
  1081. ctx_.set_verify_mode(MQTT_NS::tls::verify_peer);
  1082. ctx_.set_verify_callback(verify_cb_);
  1083. ps->next_layer().async_handshake(
  1084. tls::stream_base::server,
  1085. [
  1086. this,
  1087. socket = force_move(socket),
  1088. tim,
  1089. underlying_finished,
  1090. connection_error_called,
  1091. &ioc_con,
  1092. username
  1093. ]
  1094. (error_code ec) mutable {
  1095. if (ec) {
  1096. *underlying_finished = true;
  1097. tim->cancel();
  1098. return;
  1099. }
  1100. auto sb = std::make_shared<boost::asio::streambuf>();
  1101. auto request = std::make_shared<boost::beast::http::request<boost::beast::http::string_body>>();
  1102. auto ps = socket.get();
  1103. boost::beast::http::async_read(
  1104. ps->next_layer(),
  1105. *sb,
  1106. *request,
  1107. [
  1108. this,
  1109. socket = force_move(socket),
  1110. sb,
  1111. request,
  1112. tim,
  1113. underlying_finished,
  1114. connection_error_called,
  1115. &ioc_con,
  1116. username
  1117. ]
  1118. (error_code ec, std::size_t) mutable {
  1119. if (ec) {
  1120. *underlying_finished = true;
  1121. tim->cancel();
  1122. if (h_connection_error_ && !*connection_error_called) {
  1123. h_connection_error_(ec, ioc_con);
  1124. *connection_error_called = true;
  1125. }
  1126. return;
  1127. }
  1128. if (!boost::beast::websocket::is_upgrade(*request)) {
  1129. *underlying_finished = true;
  1130. tim->cancel();
  1131. if (h_connection_error_ && !*connection_error_called) {
  1132. h_connection_error_(
  1133. boost::system::errc::make_error_code(
  1134. boost::system::errc::protocol_error
  1135. ),
  1136. ioc_con
  1137. );
  1138. *connection_error_called = true;
  1139. }
  1140. return;
  1141. }
  1142. auto ps = socket.get();
  1143. #if BOOST_BEAST_VERSION >= 248
  1144. auto it = request->find("Sec-WebSocket-Protocol");
  1145. if (it != request->end()) {
  1146. ps->set_option(
  1147. boost::beast::websocket::stream_base::decorator(
  1148. [name = it->name(), value = it->value()] // name is enum, value is boost::string_view
  1149. (boost::beast::websocket::response_type& res) {
  1150. // This lambda is called before the scope out point *1
  1151. res.set(name, value);
  1152. }
  1153. )
  1154. );
  1155. }
  1156. ps->async_accept(
  1157. *request,
  1158. [
  1159. this,
  1160. socket = force_move(socket),
  1161. tim,
  1162. underlying_finished,
  1163. connection_error_called,
  1164. &ioc_con,
  1165. username
  1166. ]
  1167. (error_code ec) mutable {
  1168. *underlying_finished = true;
  1169. tim->cancel();
  1170. if (ec) {
  1171. if (h_connection_error_ && !*connection_error_called) {
  1172. h_connection_error_(ec, ioc_con);
  1173. *connection_error_called = true;
  1174. }
  1175. return;
  1176. }
  1177. auto sp = std::make_shared<endpoint_t>(ioc_con, force_move(socket), version_);
  1178. sp->set_preauthed_user_name(*username);
  1179. if (h_accept_) h_accept_(force_move(sp));
  1180. }
  1181. );
  1182. #else // BOOST_BEAST_VERSION >= 248
  1183. ps->async_accept_ex(
  1184. *request,
  1185. [request]
  1186. (boost::beast::websocket::response_type& m) {
  1187. auto it = request->find("Sec-WebSocket-Protocol");
  1188. if (it != request->end()) {
  1189. m.insert(it->name(), it->value());
  1190. }
  1191. },
  1192. [
  1193. this,
  1194. socket = force_move(socket),
  1195. tim,
  1196. underlying_finished,
  1197. connection_error_called,
  1198. &ioc_con,
  1199. username
  1200. ]
  1201. (error_code ec) mutable {
  1202. *underlying_finished = true;
  1203. tim->cancel();
  1204. if (ec) {
  1205. if (h_connection_error_ && *connection_error_called) {
  1206. h_connection_error_(ec, ioc_con);
  1207. *connection_error_called = true;
  1208. }
  1209. return;
  1210. }
  1211. // TODO: The use of force_move on this line of code causes
  1212. // a static assertion that socket is a const object when
  1213. // TLS is enabled, and WS is enabled, with Boost 1.70, and gcc 8.3.0
  1214. auto sp = std::make_shared<endpoint_t>(ioc_con, socket, version_);
  1215. sp->set_preauthed_user_name(*username);
  1216. if (h_accept_) h_accept_(force_move(sp));
  1217. }
  1218. );
  1219. #endif // BOOST_BEAST_VERSION >= 248
  1220. // scope out point *1
  1221. }
  1222. );
  1223. }
  1224. );
  1225. do_accept();
  1226. }
  1227. );
  1228. }
  1229. private:
  1230. verify_cb_t verify_cb_with_username_;
  1231. as::ip::tcp::endpoint ep_;
  1232. as::io_context& ioc_accept_;
  1233. as::io_context* ioc_con_ = nullptr;
  1234. std::function<as::io_context&()> ioc_con_getter_;
  1235. optional<as::ip::tcp::acceptor> acceptor_;
  1236. std::function<void(as::ip::tcp::acceptor&)> config_;
  1237. bool close_request_{false};
  1238. accept_handler h_accept_;
  1239. connection_error_handler h_connection_error_;
  1240. error_handler_with_ioc h_error_;
  1241. tls::context ctx_;
  1242. protocol_version version_ = protocol_version::undetermined;
  1243. std::chrono::steady_clock::duration underlying_connect_timeout_ = std::chrono::seconds(10);
  1244. };
  1245. #endif // defined(MQTT_USE_TLS)
  1246. #endif // defined(MQTT_USE_WS)
  1247. } // namespace MQTT_NS
  1248. #endif // MQTT_SERVER_HPP