diff --git a/src/mesh/CryptoEngine.cpp b/src/mesh/CryptoEngine.cpp index f4abff92e..96ee4cb24 100644 --- a/src/mesh/CryptoEngine.cpp +++ b/src/mesh/CryptoEngine.cpp @@ -263,15 +263,41 @@ bool CryptoEngine::setCryptoSharedSecret(meshtastic_UserLite_public_key_t pubkey uint32_t lookupKey; memcpy(&lookupKey, pubkey.bytes, sizeof(lookupKey)); - // See if we cached the secret already - auto iter = sharedSecretCache.find(lookupKey); - if (iter != sharedSecretCache.end()) { - // Cache hit! Copy it into shared_key. - CachedSharedSecret &entry = iter->second; - memcpy(shared_key, entry.shared_secret, 32); - // Update the last used timestamp - entry.last_used = now; - return true; + uint16_t oldestDelta = 0; + CachedSharedSecret &oldestEntry = sharedSecretCache[0]; + for (size_t i = 0; i < MAX_CACHED_SHARED_SECRETS; i++) { + CachedSharedSecret &entry = sharedSecretCache[i]; + if (entry.lookup_key == lookupKey) { + // Cache hit! Copy it into shared_key. + memcpy(shared_key, entry.shared_secret, 32); + // Update the last used timestamp + entry.last_used = now; + return true; + } + + if (oldestEntry.lookup_key == 0) { + // We already have a valid slot to insert into. Keep looking for a cache hit. + continue; + } + + if (entry.lookup_key == 0) { + // This entry is empty. We can insert into it later, if needed. + oldestEntry = entry; + continue; + } + + // Track the oldest entry in case the cache is full. + uint16_t delta = 0; + if (now >= entry.last_used) { + delta = now - entry.last_used; + } else { + // Assume a larger last used timestamp is further in the past + delta = uint16_t(0x100) + now - entry.last_used; + } + if (delta > oldestDelta) { + oldestEntry = entry; + oldestDelta = delta; + } } // Cache miss. Generate the shared secret. @@ -280,34 +306,10 @@ bool CryptoEngine::setCryptoSharedSecret(meshtastic_UserLite_public_key_t pubkey } hash(shared_key, 32); - // If the cache will grow too large, remove the oldest entry first - if (sharedSecretCache.size() >= MAX_CACHED_SHARED_SECRETS) { - uint16_t oldestDelta = 0; - uint32_t oldestKey = sharedSecretCache.begin()->first; - for (const auto &p : sharedSecretCache) { - const uint32_t key = p.first; - const CachedSharedSecret &entry = p.second; - - uint16_t delta = 0; - if (now >= entry.last_used) { - delta = now - entry.last_used; - } else { - // Assume a larger last used timestamp is further in the past - delta = uint16_t(0x100) + now - entry.last_used; - } - if (delta > oldestDelta) { - oldestKey = key; - oldestDelta = delta; - } - } - sharedSecretCache.erase(oldestKey); - } - - // Now insert the calculated shared secret - CachedSharedSecret entry; - entry.last_used = now; - memcpy(entry.shared_secret, shared_key, 32); - sharedSecretCache.insert({lookupKey, entry}); + // Insert the calculated shared secret into the cache, overwriting an old entry if needed. + oldestEntry.lookup_key = lookupKey; + oldestEntry.last_used = now; + memcpy(oldestEntry.shared_secret, shared_key, 32); return true; } diff --git a/src/mesh/CryptoEngine.h b/src/mesh/CryptoEngine.h index 6c2c70be5..98e40dc72 100644 --- a/src/mesh/CryptoEngine.h +++ b/src/mesh/CryptoEngine.h @@ -17,6 +17,7 @@ struct CryptoKey { }; struct CachedSharedSecret { + uint32_t lookup_key; uint8_t shared_secret[32]; uint8_t last_used; }; @@ -114,7 +115,7 @@ class CryptoEngine /** * Cache mapping peers' public keys -> {shared_secret, last_used} */ - std::unordered_map sharedSecretCache; + CachedSharedSecret sharedSecretCache[MAX_CACHED_SHARED_SECRETS] = {0}; /** * Set cryptographic (hashed) shared_key calculated from the given pubkey diff --git a/test/test_crypto/test_main.cpp b/test/test_crypto/test_main.cpp index 930887145..384bbd6ef 100644 --- a/test/test_crypto/test_main.cpp +++ b/test/test_crypto/test_main.cpp @@ -7,12 +7,13 @@ struct TestCryptoEngine { static bool getCachedSecret(uint32_t lookupKey, CachedSharedSecret &entry) { - auto iter = crypto->sharedSecretCache.find(lookupKey); - if (iter == crypto->sharedSecretCache.end()) { - return false; + for (size_t i = 0; i < MAX_CACHED_SHARED_SECRETS; i++) { + entry = crypto->sharedSecretCache[i]; + if (entry.lookup_key == lookupKey) { + return true; + } } - entry = iter->second; - return true; + return false; } };