From 3b0232de1b6282eacfbff6e50b68fca7e67b8511 Mon Sep 17 00:00:00 2001 From: Eric Severance Date: Mon, 17 Feb 2025 13:03:44 -0800 Subject: [PATCH] Validate MQTT config by testing a connection (#6076) --- src/mqtt/MQTT.cpp | 158 +++++++++++++++++++++++++--------------- src/mqtt/MQTT.h | 30 ++++---- test/test_mqtt/MQTT.cpp | 63 +++++++++++++++- 3 files changed, 173 insertions(+), 78 deletions(-) diff --git a/src/mqtt/MQTT.cpp b/src/mqtt/MQTT.cpp index 67eba82a6..5f16f909f 100644 --- a/src/mqtt/MQTT.cpp +++ b/src/mqtt/MQTT.cpp @@ -41,7 +41,6 @@ MQTT *mqtt; namespace { constexpr int reconnectMax = 5; -constexpr uint16_t mqttPort = 1883; // FIXME - this size calculation is super sloppy, but it will go away once we dynamically alloc meshpackets static uint8_t bytes[meshtastic_MqttClientProxyMessage_size + 30]; // 12 for channel name and 16 for nodeid @@ -251,6 +250,68 @@ bool isDefaultServer(const String &host) { return host.length() == 0 || host == default_mqtt_address; } + +struct PubSubConfig { + explicit PubSubConfig(const meshtastic_ModuleConfig_MQTTConfig &config) + { + if (*config.address) { + serverAddr = config.address; + mqttUsername = config.username; + mqttPassword = config.password; + } + if (config.tls_enabled) { + serverPort = 8883; + } + std::tie(serverAddr, serverPort) = parseHostAndPort(serverAddr.c_str(), serverPort); + } + + // Defaults + static constexpr uint16_t defaultPort = 1883; + uint16_t serverPort = defaultPort; + String serverAddr = default_mqtt_address; + const char *mqttUsername = default_mqtt_username; + const char *mqttPassword = default_mqtt_password; +}; + +#if HAS_NETWORKING +bool connectPubSub(const PubSubConfig &config, PubSubClient &pubSub, Client &client) +{ + pubSub.setBufferSize(1024); + pubSub.setClient(client); + pubSub.setServer(config.serverAddr.c_str(), config.serverPort); + + LOG_INFO("Connecting directly to MQTT server %s, port: %d, username: %s, password: %s", config.serverAddr.c_str(), + config.serverPort, config.mqttUsername, config.mqttPassword); + + const bool connected = pubSub.connect(owner.id, config.mqttUsername, config.mqttPassword); + if (connected) { + LOG_INFO("MQTT connected"); + } else { + LOG_WARN("Failed to connect to MQTT server"); + } + return connected; +} +#endif + +inline bool isConnectedToNetwork() +{ +#if HAS_WIFI + return WiFi.isConnected(); +#elif HAS_ETHERNET + return Ethernet.linkStatus() == LinkON; +#else + return false; +#endif +} + +/** return true if we have a channel that wants uplink/downlink or map reporting is enabled + */ +bool wantsLink() +{ + const bool hasChannelorMapReport = + moduleConfig.mqtt.enabled && (moduleConfig.mqtt.map_reporting_enabled || channels.anyMqttEnabled()); + return hasChannelorMapReport && (moduleConfig.mqtt.proxy_to_client_enabled || isConnectedToNetwork()); +} } // namespace void MQTT::mqttCallback(char *topic, byte *payload, unsigned int length) @@ -413,46 +474,18 @@ void MQTT::reconnect() return; // Don't try to connect directly to the server } #if HAS_NETWORKING - // Defaults - int serverPort = mqttPort; - const char *serverAddr = default_mqtt_address; - const char *mqttUsername = default_mqtt_username; - const char *mqttPassword = default_mqtt_password; + const PubSubConfig config(moduleConfig.mqtt); MQTTClient *clientConnection = mqttClient.get(); - - if (*moduleConfig.mqtt.address) { - serverAddr = moduleConfig.mqtt.address; - mqttUsername = moduleConfig.mqtt.username; - mqttPassword = moduleConfig.mqtt.password; - } -#if HAS_WIFI && !defined(ARCH_PORTDUINO) && !defined(CONFIG_IDF_TARGET_ESP32C6) +#if MQTT_SUPPORTS_TLS if (moduleConfig.mqtt.tls_enabled) { - // change default for encrypted to 8883 - try { - serverPort = 8883; - wifiSecureClient.setInsecure(); - LOG_INFO("Use TLS-encrypted session"); - clientConnection = &wifiSecureClient; - } catch (const std::exception &e) { - LOG_ERROR("MQTT ERROR: %s", e.what()); - } + mqttClientTLS.setInsecure(); + LOG_INFO("Use TLS-encrypted session"); + clientConnection = &mqttClientTLS; } else { LOG_INFO("Use non-TLS-encrypted session"); } #endif - std::pair hostAndPort = parseHostAndPort(serverAddr, serverPort); - serverAddr = hostAndPort.first.c_str(); - serverPort = hostAndPort.second; - pubSub.setServer(serverAddr, serverPort); - pubSub.setBufferSize(1024); - - LOG_INFO("Connect directly to MQTT server %s, port: %d, username: %s, password: %s", serverAddr, serverPort, mqttUsername, - mqttPassword); - - pubSub.setClient(*clientConnection); - bool connected = pubSub.connect(owner.id, mqttUsername, mqttPassword); - if (connected) { - LOG_INFO("MQTT connected"); + if (connectPubSub(config, pubSub, *clientConnection)) { enabled = true; // Start running background process again runASAP = true; reconnectCount = 0; @@ -507,23 +540,6 @@ void MQTT::sendSubscriptions() #endif } -bool MQTT::wantsLink() const -{ - bool hasChannelorMapReport = - moduleConfig.mqtt.enabled && (moduleConfig.mqtt.map_reporting_enabled || channels.anyMqttEnabled()); - - if (hasChannelorMapReport && moduleConfig.mqtt.proxy_to_client_enabled) - return true; - -#if HAS_WIFI - return hasChannelorMapReport && WiFi.isConnected(); -#endif -#if HAS_ETHERNET - return hasChannelorMapReport && Ethernet.linkStatus() == LinkON; -#endif - return false; -} - int32_t MQTT::runOnce() { #if HAS_NETWORKING @@ -567,18 +583,42 @@ int32_t MQTT::runOnce() return 30000; } -bool MQTT::isValidConfig(const meshtastic_ModuleConfig_MQTTConfig &config) +bool MQTT::isValidConfig(const meshtastic_ModuleConfig_MQTTConfig &config, MQTTClient *client) { - String host; - uint16_t port; - std::tie(host, port) = parseHostAndPort(config.address, mqttPort); - const bool defaultServer = isDefaultServer(host); + const PubSubConfig parsed(config); + + if (config.enabled && !config.proxy_to_client_enabled) { +#if HAS_NETWORKING + std::unique_ptr clientConnection; + if (config.tls_enabled) { +#if MQTT_SUPPORTS_TLS + MQTTClientTLS *tlsClient = new MQTTClientTLS; + clientConnection.reset(tlsClient); + tlsClient->setInsecure(); +#else + LOG_ERROR("Invalid MQTT config: tls_enabled is not supported on this node"); + return false; +#endif + } else { + clientConnection.reset(new MQTTClient); + } + std::unique_ptr pubSub(new PubSubClient); + if (isConnectedToNetwork()) { + return connectPubSub(parsed, *pubSub, (client != nullptr) ? *client : *clientConnection); + } +#else + LOG_ERROR("Invalid MQTT config: proxy_to_client_enabled must be enabled on nodes that do not have a network"); + return false; +#endif + } + + const bool defaultServer = isDefaultServer(parsed.serverAddr); if (defaultServer && config.tls_enabled) { LOG_ERROR("Invalid MQTT config: TLS was enabled, but the default server does not support TLS"); return false; } - if (defaultServer && port != mqttPort) { - LOG_ERROR("Invalid MQTT config: Unsupported port '%d' for the default MQTT server", port); + if (defaultServer && parsed.serverPort != PubSubConfig::defaultPort) { + LOG_ERROR("Invalid MQTT config: Unsupported port '%d' for the default MQTT server", parsed.serverPort); return false; } return true; diff --git a/src/mqtt/MQTT.h b/src/mqtt/MQTT.h index f7e3864f8..5cda90218 100644 --- a/src/mqtt/MQTT.h +++ b/src/mqtt/MQTT.h @@ -10,12 +10,10 @@ #endif #if HAS_WIFI #include -#if !defined(ARCH_PORTDUINO) -#if defined(ESP_ARDUINO_VERSION_MAJOR) && ESP_ARDUINO_VERSION_MAJOR < 3 +#if __has_include() #include #endif #endif -#endif #if HAS_ETHERNET #include #endif @@ -61,7 +59,8 @@ class MQTT : private concurrency::OSThread bool isUsingDefaultServer() { return isConfiguredForDefaultServer; } - static bool isValidConfig(const meshtastic_ModuleConfig_MQTTConfig &config); + /// Validate the meshtastic_ModuleConfig_MQTTConfig. + static bool isValidConfig(const meshtastic_ModuleConfig_MQTTConfig &config) { return isValidConfig(config, nullptr); } protected: struct QueueEntry { @@ -78,22 +77,23 @@ class MQTT : private concurrency::OSThread #ifndef PIO_UNIT_TESTING private: #endif - // supposedly the current version is busted: - // http://www.iotsharing.com/2017/08/how-to-use-esp32-mqtts-with-mqtts-mosquitto-broker-tls-ssl.html #if HAS_WIFI using MQTTClient = WiFiClient; -#if !defined(ARCH_PORTDUINO) -#if (defined(ESP_ARDUINO_VERSION_MAJOR) && ESP_ARDUINO_VERSION_MAJOR < 3) || defined(RPI_PICO) - WiFiClientSecure wifiSecureClient; +#if __has_include() + using MQTTClientTLS = WiFiClientSecure; +#define MQTT_SUPPORTS_TLS 1 #endif -#endif -#endif -#if HAS_ETHERNET +#elif HAS_ETHERNET using MQTTClient = EthernetClient; +#else + using MQTTClient = void; #endif #if HAS_NETWORKING std::unique_ptr mqttClient; +#if MQTT_SUPPORTS_TLS + MQTTClientTLS mqttClientTLS; +#endif PubSubClient pubSub; explicit MQTT(std::unique_ptr mqttClient); #endif @@ -109,10 +109,6 @@ class MQTT : private concurrency::OSThread uint32_t map_position_precision = default_map_position_precision; uint32_t map_publish_interval_msecs = default_map_publish_interval_secs * 1000; - /** return true if we have a channel that wants uplink/downlink or map reporting is enabled - */ - bool wantsLink() const; - /** Attempt to connect to server if necessary */ void reconnect(); @@ -124,6 +120,8 @@ class MQTT : private concurrency::OSThread /// Callback for direct mqtt subscription messages static void mqttCallback(char *topic, byte *payload, unsigned int length); + static bool isValidConfig(const meshtastic_ModuleConfig_MQTTConfig &config, MQTTClient *client); + /// Called when a new publish arrives from the MQTT server void onReceive(char *topic, byte *payload, size_t length); diff --git a/test/test_mqtt/MQTT.cpp b/test/test_mqtt/MQTT.cpp index c00922548..50a98001a 100644 --- a/test/test_mqtt/MQTT.cpp +++ b/test/test_mqtt/MQTT.cpp @@ -94,6 +94,7 @@ class MockPubSubServer : public WiFiClient int connect(IPAddress ip, uint16_t port) override { + port_ = port; if (refuseConnection_) return 0; connected_ = true; @@ -101,6 +102,8 @@ class MockPubSubServer : public WiFiClient } int connect(const char *host, uint16_t port) override { + host_ = host; + port_ = port; if (refuseConnection_) return 0; connected_ = true; @@ -197,6 +200,8 @@ class MockPubSubServer : public WiFiClient bool connected_ = false; bool refuseConnection_ = false; // Simulate a failed connection. uint32_t ipAddress_ = 0x01010101; // IP address of the MQTT server. + std::string host_; // Requested host. + uint16_t port_; // Requested port. std::list buffer_; // Buffer of messages for the pubSub client to receive. std::string command_; // Current command received from the pubSub client. std::set subscriptions_; // Topics that the pubSub client has subscribed to. @@ -242,6 +247,7 @@ class MQTTUnitTest : public MQTT mqttClient.release(); delete pubsub; } + using MQTT::isValidConfig; using MQTT::reconnect; int queueSize() { return mqttQueue.numUsed(); } void reportToMap(std::optional precision = std::nullopt) @@ -801,13 +807,25 @@ void test_customMqttRoot(void) } // Empty configuration is valid. -void test_configurationEmptyIsValid(void) +void test_configEmptyIsValid(void) { - meshtastic_ModuleConfig_MQTTConfig config; + meshtastic_ModuleConfig_MQTTConfig config = {}; TEST_ASSERT_TRUE(MQTT::isValidConfig(config)); } +// Empty 'enabled' configuration is valid. +void test_configEnabledEmptyIsValid(void) +{ + meshtastic_ModuleConfig_MQTTConfig config = {.enabled = true}; + MockPubSubServer client; + + TEST_ASSERT_TRUE(MQTTUnitTest::isValidConfig(config, &client)); + TEST_ASSERT_TRUE(client.connected_); + TEST_ASSERT_EQUAL_STRING(default_mqtt_address, client.host_.c_str()); + TEST_ASSERT_EQUAL(1883, client.port_); +} + // Configuration with the default server is valid. void test_configWithDefaultServer(void) { @@ -832,6 +850,41 @@ void test_configWithDefaultServerAndInvalidTLSEnabled(void) TEST_ASSERT_FALSE(MQTT::isValidConfig(config)); } +// isValidConfig connects to a custom host and port. +void test_configCustomHostAndPort(void) +{ + meshtastic_ModuleConfig_MQTTConfig config = {.enabled = true, .address = "server:1234"}; + MockPubSubServer client; + + TEST_ASSERT_TRUE(MQTTUnitTest::isValidConfig(config, &client)); + TEST_ASSERT_TRUE(client.connected_); + TEST_ASSERT_EQUAL_STRING("server", client.host_.c_str()); + TEST_ASSERT_EQUAL(1234, client.port_); +} + +// isValidConfig returns false if a connection cannot be established. +void test_configWithConnectionFailure(void) +{ + meshtastic_ModuleConfig_MQTTConfig config = {.enabled = true, .address = "server"}; + MockPubSubServer client; + client.refuseConnection_ = true; + + TEST_ASSERT_FALSE(MQTTUnitTest::isValidConfig(config, &client)); +} + +// isValidConfig returns true when tls_enabled is supported, or false otherwise. +void test_configWithTLSEnabled(void) +{ + meshtastic_ModuleConfig_MQTTConfig config = {.enabled = true, .address = "server", .tls_enabled = true}; + MockPubSubServer client; + +#if MQTT_SUPPORTS_TLS + TEST_ASSERT_TRUE(MQTTUnitTest::isValidConfig(config, &client)); +#else + TEST_ASSERT_FALSE(MQTTUnitTest::isValidConfig(config, &client)); +#endif +} + void setup() { initializeTestEnvironment(); @@ -875,10 +928,14 @@ void setup() RUN_TEST(test_enabled); RUN_TEST(test_disabled); RUN_TEST(test_customMqttRoot); - RUN_TEST(test_configurationEmptyIsValid); + RUN_TEST(test_configEmptyIsValid); + RUN_TEST(test_configEnabledEmptyIsValid); RUN_TEST(test_configWithDefaultServer); RUN_TEST(test_configWithDefaultServerAndInvalidPort); RUN_TEST(test_configWithDefaultServerAndInvalidTLSEnabled); + RUN_TEST(test_configCustomHostAndPort); + RUN_TEST(test_configWithConnectionFailure); + RUN_TEST(test_configWithTLSEnabled); exit(UNITY_END()); } #else