free web page counters

Windows Mobile Pocket PC Smartphone Programming

==>Click here for the SiteMap<==. Original content with a decent amount of source code.

Saturday, March 25, 2006

Windows Mobile Secure Socket Implementation Series 2: Sample Source Code

====>SiteMap of this Blog<===

Windows Mobile Secure Socket Implementation Series 2: Sample Source Code

In the previous post, I discussed general ideas for writing secure socket applications on Windows Mobile devices (or, more generally, on Windows CE platforms). This post provides the source code along with brief explanations of it.



Function: Establish a secure socket connection to a remote host:port

Notice that the remote host name is passed as the certificate validation function's pvArg, which Winsock hands back to us when it calls the validation function.


int SecureConnect(int in_socket,
   const struct sockaddr * in_pHostInetAddr,
   int in_port )
{
   struct sockaddr_in theInetAddr;
   int sockerror = 0;

   /* initialize the address structures */
   memset (&theInetAddr, 0, sizeof(theInetAddr));

   theInetAddr.sin_family = AF_INET;
   theInetAddr.sin_addr.s_addr = htonl(INADDR_ANY);
   theInetAddr.sin_port = htons(in_port);

   // make it secure
   {
      DWORD dwOptVal = SO_SEC_SSL;
      DWORD dwBytes = 0;
      SSLVALIDATECERTHOOK sslValidateFunc;

      sockerror = setsockopt(in_socket, SOL_SOCKET,
         SO_SECURE, (LPSTR)&dwOptVal, sizeof(dwOptVal));

      if (SOCKET_ERROR == sockerror){
         // error logging
         return 0;
      }

      // register certificate validation callback
      sslValidateFunc.HookFunc = certificateValidationCallback;
      sslValidateFunc.pvArg = "www.sample.com"; // ... passing server name from your context

      sockerror = WSAIoctl(in_socket, SO_SSL_SET_VALIDATE_CERT_HOOK,
         &sslValidateFunc, sizeof(sslValidateFunc), NULL, 0, &dwBytes, NULL, NULL);

      if (SOCKET_ERROR == sockerror){
         // error logging
         return 0;
      }
   }


   // connect
   sockerror = connect(in_socket,
      (struct sockaddr *)in_pHostInetAddr, sizeof (*in_pHostInetAddr));
   if (sockerror == SOCKET_ERROR) {
      // error logging
      return 0;
   }

   {
      // show SSLCONNECTIONINFO
      SSLCONNECTIONINFO SSLConnectionInfo;
      DWORD dwBytes = 0;
      sockerror = WSAIoctl(in_socket, SO_SSL_GET_CONNECTION_INFO,
         NULL, 0, &SSLConnectionInfo, sizeof(SSLConnectionInfo), &dwBytes, NULL,NULL);
      if (sockerror == SOCKET_ERROR) {
         // error logging
         return 0;
      }
      ShowConnectionInfo(&SSLConnectionInfo);
   }

   return 1;
}


Function: Load the functions from schannel.dll, to crack a BLOB into X.509 certificate

The two function pointers are defined as global variables so that they do not need to be loaded and unloaded dynamically during every SSL handshake.



#include <wincrypt.h>
#include <schnlsp.h>

// load SslCrackCertificate and SslFreeCertificate
// Export names looked up in schannel.dll by LoadSSL via GetProcAddress.
// They are wrapped in TEXT(); NOTE(review): presumably because GetProcAddress
// on Windows CE takes a TCHAR (wide) name - confirm for your target SDK.
#define SSL_CRACK_CERTIFICATE_NAME TEXT("SslCrackCertificate")
#define SSL_FREE_CERTIFICATE_NAME TEXT("SslFreeCertificate")

HRESULT LoadSSL()
{
   // already loaded?
   if (gSslCrackCertificate && gSslFreeCertificate) return S_OK;

   hSchannelDLL = LoadLibrary(TEXT("schannel.dll"));
   if (!hSchannelDLL) {
      // error logging
      return E_FAIL;
   }

   gSslCrackCertificate = (SSL_CRACK_CERTIFICATE_FN)GetProcAddress(hSchannelDLL, SSL_CRACK_CERTIFICATE_NAME);
   gSslFreeCertificate = (SSL_FREE_CERTIFICATE_FN)GetProcAddress(hSchannelDLL, SSL_FREE_CERTIFICATE_NAME);

   if (!gSslCrackCertificate || !gSslFreeCertificate) {
      // error logging
      gSslCrackCertificate = NULL;
      gSslFreeCertificate = NULL;
      FreeLibrary(hSchannelDLL);
      hSchannelDLL = NULL;
      return E_FAIL;
   } else {
      return S_OK;
   }
}

HRESULT FreeSSL()
{
   if (hSchannelDLL) {
      FreeLibrary(hSchannelDLL);
      hSchannelDLL = NULL;
   }
   return S_OK;
}


Function: Show SSL connection information for debugging or learning purposes

Do not put it into your release build.



/*
 * ShowConnectionInfo - display negotiated SSL parameters in a message box.
 *
 * Debug/learning aid only: it blocks on MessageBox, so keep it out of
 * release builds. pConnectionInfo is the SSLCONNECTIONINFO returned by
 * WSAIoctl(SO_SSL_GET_CONNECTION_INFO). It shows, one per line: the protocol,
 * cipher, cipher strength, hash, hash strength, key exchange and key exchange
 * strength. Known algorithm IDs are shown by name, others as hex.
 * A NULL argument is ignored.
 */
void ShowConnectionInfo(SSLCONNECTIONINFO *pConnectionInfo)
{
   TCHAR szTemp[1028] = {0};

   if (pConnectionInfo == NULL) {
      return;
   }

   switch(pConnectionInfo->dwProtocol)
   {
   case SSL_PROTOCOL_SSL3:
      swprintf(szTemp+wcslen(szTemp), TEXT("Protocol: SSL3"));
      break;

   case SSL_PROTOCOL_PCT1:
      swprintf(szTemp+wcslen(szTemp), TEXT("Protocol: PCT"));
      break;

   case SSL_PROTOCOL_SSL2:
      swprintf(szTemp+wcslen(szTemp), TEXT("Protocol: SSL2"));
      break;

   default:
      // dwProtocol is a DWORD (unsigned long): use the 'l' length modifier
      swprintf(szTemp+wcslen(szTemp), TEXT("Protocol: 0x%08lx"), pConnectionInfo->dwProtocol);
      break;
   }
   wcscat(szTemp, TEXT("\n"));

   switch(pConnectionInfo->aiCipher)
   {
   case CALG_RC4:
      swprintf(szTemp+wcslen(szTemp), TEXT("Cipher: RC4"));
      break;

   case CALG_3DES:
      swprintf(szTemp+wcslen(szTemp), TEXT("Cipher: Triple DES"));
      break;

   case CALG_RC2:
      swprintf(szTemp+wcslen(szTemp), TEXT("Cipher: RC2"));
      break;

   case CALG_DES:
      swprintf(szTemp+wcslen(szTemp), TEXT("Cipher: DES"));
      break;

   case CALG_SKIPJACK:
      swprintf(szTemp+wcslen(szTemp), TEXT("Cipher: Skipjack"));
      break;

   default:
      // the format is the literal, not szTemp (which was previously passed
      // by mistake, making the buffer its own format string)
      swprintf(szTemp+wcslen(szTemp), TEXT("Cipher: 0x%08x"), pConnectionInfo->aiCipher);
      break;
   }
   wcscat(szTemp, TEXT("\n"));

   swprintf(szTemp+wcslen(szTemp), TEXT("Cipher strength: %lu"), pConnectionInfo->dwCipherStrength);
   wcscat(szTemp, TEXT("\n"));

   switch(pConnectionInfo->aiHash)
   {
   case CALG_MD5:
      swprintf(szTemp+wcslen(szTemp), TEXT("Hash: MD5"));
      break;

   case CALG_SHA:
      swprintf(szTemp+wcslen(szTemp), TEXT("Hash: SHA"));
      break;

   default:
      swprintf(szTemp+wcslen(szTemp), TEXT("Hash: 0x%08x"), pConnectionInfo->aiHash);
      break;
   }
   wcscat(szTemp, TEXT("\n"));

   swprintf(szTemp+wcslen(szTemp), TEXT("Hash strength: %lu"), pConnectionInfo->dwHashStrength);
   wcscat(szTemp, TEXT("\n"));

   switch(pConnectionInfo->aiExch)
   {
   case CALG_RSA_KEYX:
   case CALG_RSA_SIGN:
      swprintf(szTemp+wcslen(szTemp), TEXT("Key exchange: RSA"));
      break;

   case CALG_KEA_KEYX:
      swprintf(szTemp+wcslen(szTemp), TEXT("Key exchange: KEA"));
      break;

   default:
      swprintf(szTemp+wcslen(szTemp), TEXT("Key exchange: 0x%08x"), pConnectionInfo->aiExch);
      break;
   }
   wcscat(szTemp, TEXT("\n"));

   swprintf(szTemp+wcslen(szTemp), TEXT("Key exchange strength: %lu"), pConnectionInfo->dwExchStrength);

   // last argument is the UINT uType flag set; 0 == MB_OK (single OK button)
   MessageBox(NULL, szTemp, NULL, 0);
}


Function: Certificate validation callback

Notice the order in which it checks the certificate and the other parameters: Winsock sometimes reports false alarms in dwFlags. The way Winsock handles certificates internally also changed between Windows Mobile 2003 devices (based on CE 4.2) and Windows Mobile 5.0 devices (based on CE 5.0). This will be discussed in a later post.



// the certificate validation for SSL
/*
 * certificateValidationCallback - certificate hook registered by SecureConnect
 * through SO_SSL_SET_VALIDATE_CERT_HOOK.
 *
 * dwType     - certificate format; only SSL_CERT_X509 is accepted.
 * pvArg      - expected server host name (char*), as registered by the caller.
 * dwChainLen - number of certificates in pCertChain; this code expects 1.
 * pCertChain - encoded certificate BLOB, cracked with gSslCrackCertificate.
 * dwFlags    - SSL_CERT_FLAG_* bits reported by Winsock.
 *
 * Returns:
 *   SSL_ERR_OKAY         - the certificate's CN matches pvArg (case-insensitive);
 *   SSL_ERR_BAD_DATA     - the BLOB could not be cracked;
 *   SSL_ERR_CERT_UNKNOWN - any other rejection.
 *
 * Requires LoadSSL to have succeeded; otherwise every certificate is rejected.
 */
int certificateValidationCallback(
   DWORD dwType,
   LPVOID pvArg,
   DWORD dwChainLen,
   LPBLOB pCertChain,
   DWORD dwFlags)
{
   X509Certificate* pCert = NULL;
   int nRet = SSL_ERR_CERT_UNKNOWN;
   const char* pchRemoteHost = (const char*)pvArg;

   // dwType must be SSL_CERT_X509
   if (dwType != SSL_CERT_X509) {
      // error logging
      return nRet;
   }

   // reject when Winsock reports the issuer as unknown
   if (dwFlags & SSL_CERT_FLAG_ISSUER_UNKNOWN) {
      // error logging
      return nRet;
   }

   if (pCertChain == NULL) return nRet;
   ASSERT(dwChainLen == 1);

   // Without an expected host name the CN cannot be verified; reject instead
   // of passing NULL to _stricmp below.
   if (pchRemoteHost == NULL) {
      // error logging
      return nRet;
   }

   if (!gSslCrackCertificate || !gSslFreeCertificate) {
      // error logging
      return nRet; // unable to crack
   }

   // crack X.509 Certificate
   if (!gSslCrackCertificate(pCertChain->pBlobData, pCertChain->cbSize, TRUE, &pCert)) {
      // error logging
      return SSL_ERR_BAD_DATA;
   }

   // Site check: the subject's CN must equal the expected host name
   {
      char* pchSubject = pCert->pszSubject;
      char* pchCN = NULL;

      if (pchSubject == NULL) {
         // error logging
         goto FuncExit;
      }

      // here you need to parse the subject to retrieve the CN name
      // NOTE(review): if ParseCN allocates its result, pchCN is never freed -
      // confirm ParseCN's ownership contract.
      pchCN = ParseCN(pchSubject);
      if (!pchCN) {
         goto FuncExit;
      }

      // CN comparison
      if (_stricmp(pchRemoteHost, pchCN) != 0) {
         // error logging
         goto FuncExit;
      }
   }

   // validFrom, validUntil check
   {
      SYSTEMTIME stNow;
      FILETIME ftNow;
      FILETIME ftValidFrom = pCert->ValidFrom;
      FILETIME ftValidUntil = pCert->ValidUntil;

      GetSystemTime(&stNow);
      SystemTimeToFileTime(&stNow, &ftNow);

      if (!(IsEarlierThan(&ftValidFrom, &ftNow) && IsEarlierThan(&ftNow, &ftValidUntil))) {
         // give user an option to continue or not
         // a little more lenient than Subject check
         // TODO: this branch is currently empty, so a certificate outside its
         // validity period is still accepted. Prompt the user or reject here.
      }
   }

   nRet = SSL_ERR_OKAY;

FuncExit:
   gSslFreeCertificate(pCert);

   return nRet;
}


Category: [SSL / Socket / Networking / Connection Manager]

====>SiteMap of this Blog<===