diff --git a/dll/shellext/wlanwiz/main.h b/dll/shellext/wlanwiz/main.h index 959b617f8bd55..9f24ec2a7347f 100644 --- a/dll/shellext/wlanwiz/main.h +++ b/dll/shellext/wlanwiz/main.h @@ -6,6 +6,8 @@ */ #pragma once +#include + #include #include #include @@ -168,6 +170,9 @@ class CWlanWizard : public CDialogImpl HWND CreateToolTip(int toolID); static DWORD WINAPI ScanNetworksThread(_In_ LPVOID lpParameter); + void TryInsertToAdHoc(std::set& setAdHoc, DWORD dwIndex); + void TryInsertToKnown(std::set& setProfiles, DWORD dwIndex); + DWORD TryFindConnected(DWORD dwIndex); LRESULT OnInitDialog(UINT nMsg, WPARAM wParam, LPARAM lParam, BOOL& bHandled); LRESULT OnDrawItem(UINT nMsg, WPARAM wParam, LPARAM lParam, BOOL& bHandled); diff --git a/dll/shellext/wlanwiz/scan.cpp b/dll/shellext/wlanwiz/scan.cpp index f930cd8cf9d45..80062180b2561 100644 --- a/dll/shellext/wlanwiz/scan.cpp +++ b/dll/shellext/wlanwiz/scan.cpp @@ -9,10 +9,51 @@ #include #include -#include #include #include +/* Convert SSID from UTF-8 to UTF-16 */ +static ATL::CStringW APNameToUnicode(PDOT11_SSID dot11Ssid) +{ + int iSSIDLengthWide = MultiByteToWideChar(CP_UTF8, 0, reinterpret_cast(dot11Ssid->ucSSID), dot11Ssid->uSSIDLength, NULL, 0); + + ATL::CStringW cswSSID = ATL::CStringW(L"", iSSIDLengthWide); + MultiByteToWideChar(CP_UTF8, 0, reinterpret_cast(dot11Ssid->ucSSID), dot11Ssid->uSSIDLength, cswSSID.GetBuffer(), iSSIDLengthWide); + + return cswSSID; +} + +void CWlanWizard::TryInsertToKnown(std::set& setProfiles, DWORD dwIndex) +{ + PWLAN_AVAILABLE_NETWORK pWlanNetwork = &this->lstWlanNetworks->Network[dwIndex]; + + if ((pWlanNetwork->dwFlags & WLAN_AVAILABLE_NETWORK_HAS_PROFILE) == WLAN_AVAILABLE_NETWORK_HAS_PROFILE) + { + std::wstring_view wsvSSID = APNameToUnicode(&pWlanNetwork->dot11Ssid); + + if (wsvSSID == pWlanNetwork->strProfileName) + setProfiles.insert(dwIndex); + } +} + +void CWlanWizard::TryInsertToAdHoc(std::set& setAdHoc, DWORD dwIndex) +{ + PWLAN_AVAILABLE_NETWORK pWlanNetwork = &this->lstWlanNetworks->Network[dwIndex]; + + if (pWlanNetwork->dot11BssType == dot11_BSS_type_independent) + setAdHoc.insert(dwIndex); +} + +DWORD CWlanWizard::TryFindConnected(DWORD dwIndex) +{ + PWLAN_AVAILABLE_NETWORK pWlanNetwork = &this->lstWlanNetworks->Network[dwIndex]; + + if ((pWlanNetwork->dwFlags & WLAN_AVAILABLE_NETWORK_CONNECTED) == WLAN_AVAILABLE_NETWORK_CONNECTED) + return dwIndex; + + return MAXDWORD; +} + LRESULT CWlanWizard::OnScanNetworks(WORD wNotifyCode, WORD wID, HWND hWndCtl, BOOL& bHandled) { MSG msg; @@ -90,43 +131,64 @@ LRESULT CWlanWizard::OnScanNetworks(WORD wNotifyCode, WORD wID, HWND hWndCtl, BO m_SidebarButtonAS.EnableWindow(); m_SidebarButtonSN.EnableWindow(); - - DPRINT("Discovered %lu access points\n", this->lstWlanNetworks->dwNumberOfItems); if (this->lstWlanNetworks->dwNumberOfItems > 0) { auto vecIndexesBySignalQuality = std::vector(this->lstWlanNetworks->dwNumberOfItems); DWORD dwConnectedTo = MAXDWORD; std::set setDiscoveredAdHocIndexes; + std::set setAPsWithProfiles; std::iota(vecIndexesBySignalQuality.begin(), vecIndexesBySignalQuality.end(), 0); /* Sort networks by signal level */ - std::sort(vecIndexesBySignalQuality.begin(), vecIndexesBySignalQuality.end(), [&](DWORD left, DWORD right) + std::sort(vecIndexesBySignalQuality.begin(), vecIndexesBySignalQuality.end(), [&](auto left, auto right) { - WLAN_AVAILABLE_NETWORK wlanLeft = this->lstWlanNetworks->Network[left]; - WLAN_AVAILABLE_NETWORK wlanRight = this->lstWlanNetworks->Network[right]; + TryInsertToAdHoc(setDiscoveredAdHocIndexes, left); + TryInsertToAdHoc(setDiscoveredAdHocIndexes, right); + + /* Try to determine if we are connected currently to anything. + * Once found, these two steps are skipped. */ + if (dwConnectedTo == MAXDWORD) + dwConnectedTo = TryFindConnected(left); - if (wlanLeft.dot11BssType == dot11_BSS_type_independent) - setDiscoveredAdHocIndexes.insert(left); + if (dwConnectedTo == MAXDWORD) + dwConnectedTo = TryFindConnected(right); - if (wlanLeft.dwFlags & WLAN_AVAILABLE_NETWORK_CONNECTED) - dwConnectedTo = left; + /* Count network as known if it fully matches SSID with profile name */ + TryInsertToKnown(setAPsWithProfiles, left); + TryInsertToKnown(setAPsWithProfiles, right); - return wlanLeft.wlanSignalQuality > wlanRight.wlanSignalQuality; + return this->lstWlanNetworks->Network[left].wlanSignalQuality > this->lstWlanNetworks->Network[right].wlanSignalQuality; }); - /* Shift all ad hoc networks to end */ - if (setDiscoveredAdHocIndexes.size() > 0) + /* Remove networks that do not have profile name exactly matching SSID */ + for (const auto& dwKnownAPIdx : setAPsWithProfiles) { - for (const auto& dwAdHocIdx : setDiscoveredAdHocIndexes) + PWLAN_AVAILABLE_NETWORK wlanNetWithProfile = &this->lstWlanNetworks->Network[dwKnownAPIdx]; + + vecIndexesBySignalQuality.erase(std::remove_if(vecIndexesBySignalQuality.begin(), vecIndexesBySignalQuality.end(), [&](const DWORD& dwAP) { - auto iter = std::find(vecIndexesBySignalQuality.begin(), vecIndexesBySignalQuality.end(), dwAdHocIdx); + if (dwKnownAPIdx == dwAP) + return false; - if (iter != vecIndexesBySignalQuality.end()) - { - auto idx = iter - vecIndexesBySignalQuality.begin(); - std::rotate(vecIndexesBySignalQuality.begin() + idx, vecIndexesBySignalQuality.begin() + idx + 1, vecIndexesBySignalQuality.end()); - } + bool bProfileNameIsSSID = std::wstring_view(wlanNetWithProfile->strProfileName) == std::wstring_view(APNameToUnicode(&this->lstWlanNetworks->Network[dwAP].dot11Ssid)); + bool bHasProfile = (this->lstWlanNetworks->Network[dwAP].dwFlags & WLAN_AVAILABLE_NETWORK_HAS_PROFILE) != 0; + + return bProfileNameIsSSID && !bHasProfile; + }), vecIndexesBySignalQuality.end()); + } + + DPRINT("Discovered %lu access points (%d are known)\n", this->lstWlanNetworks->dwNumberOfItems, setAPsWithProfiles.size()); + + /* Shift all ad hoc networks to end */ + for (const auto& dwAdHocIdx : setDiscoveredAdHocIndexes) + { + auto iter = std::find(vecIndexesBySignalQuality.begin(), vecIndexesBySignalQuality.end(), dwAdHocIdx); + + if (iter != vecIndexesBySignalQuality.end()) + { + auto idx = iter - vecIndexesBySignalQuality.begin(); + std::rotate(vecIndexesBySignalQuality.begin() + idx, vecIndexesBySignalQuality.begin() + idx + 1, vecIndexesBySignalQuality.end()); } } @@ -134,10 +196,14 @@ LRESULT CWlanWizard::OnScanNetworks(WORD wNotifyCode, WORD wID, HWND hWndCtl, BO if (dwConnectedTo != MAXDWORD) { auto connectedIdx = std::find(vecIndexesBySignalQuality.begin(), vecIndexesBySignalQuality.end(), dwConnectedTo) - vecIndexesBySignalQuality.begin(); - std::rotate(vecIndexesBySignalQuality.begin() + connectedIdx, vecIndexesBySignalQuality.begin() + connectedIdx + 1, vecIndexesBySignalQuality.end()); + + if (connectedIdx > 0) + { + auto iter = vecIndexesBySignalQuality.begin(); + std::rotate(iter, iter + connectedIdx, iter + connectedIdx + 1); + } } - /* TODO: remove networks that do not have a saved profile matching the SSID */ for (const auto& dwNetwork : vecIndexesBySignalQuality) { WLAN_AVAILABLE_NETWORK wlanNetwork = this->lstWlanNetworks->Network[dwNetwork];