source-engine/utils/vmpi/ThreadedTCPSocket.cpp

1086 lines
25 KiB
C++
Raw Permalink Normal View History

2020-04-22 12:56:21 -04:00
//========= Copyright Valve Corporation, All rights reserved. ============//
//
// Purpose:
//
// $NoKeywords: $
//
//=============================================================================//
// ThreadedTCPSocket.cpp : Defines the entry point for the console application.
//
#include <winsock2.h>
#include <mswsock.h>
#include "IThreadedTCPSocket.h"
#include "utllinkedlist.h"
#include "threadhelpers.h"
#include "iphelpers.h"
#include "tier1/strtools.h"
#define SEND_KEEPALIVE_INTERVAL 3000
#define KEEPALIVE_TIMEOUT 25000 // The sockets timeout after this long.
#define KEEPALIVE_SENTINEL -12345 // When first 4 bytes of a packet = this, then it's just a keepalive.
static int g_KeepaliveSentinel = KEEPALIVE_SENTINEL;
bool g_bHandleTimeouts = true;
// If true, it'll set the socket thread priorities lower than normal.
bool g_bSetTCPSocketThreadPriorities = true;
// We get crashes at runtime if they don't link in the multithreaded runtime libraries,
// so raise a ruckus if they're using singlethreaded libraries.
#ifndef _MT
#pragma message( "**** WARNING **** ThreadedTCPSocket requires multithreaded runtime libraries to be used.\n" )
class MTChecker
{
public:
MTChecker() { Assert( false ); }
} g_MTChecker;
#endif
// ------------------------------------------------------------------------------------------------ //
// Static helpers.
// ------------------------------------------------------------------------------------------------ //
static SOCKET TCPBind( const CIPAddr *pAddr )
{
// Create a socket to send and receive through.
SOCKET sock = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED );
if ( sock == INVALID_SOCKET )
{
Assert( false );
return INVALID_SOCKET;
}
// bind to it!
sockaddr_in addr;
IPAddrToSockAddr( pAddr, &addr );
int status = bind( sock, (sockaddr*)&addr, sizeof(addr) );
if ( status == 0 )
{
return sock;
}
else
{
closesocket( sock );
return INVALID_SOCKET;
}
}
// ------------------------------------------------------------------------------------------------ //
// CTCPPacket.
// ------------------------------------------------------------------------------------------------ //
int CTCPPacket::GetUserData() const
{
return m_UserData;
}
void CTCPPacket::SetUserData( int userData )
{
m_UserData = userData;
}
void CTCPPacket::Release()
{
free( this );
}
// ------------------------------------------------------------------------------------------------ //
// CThreadedTCPSocket.
// ------------------------------------------------------------------------------------------------ //
class CThreadedTCPSocket : public IThreadedTCPSocket
{
public:
static IThreadedTCPSocket* Create( SOCKET iSocket, CIPAddr remoteAddr, ITCPSocketHandler *pHandler )
{
CThreadedTCPSocket *pRet = new CThreadedTCPSocket;
if ( pRet->Init( iSocket, remoteAddr, pHandler ) )
{
return pRet;
}
else
{
pRet->Release();
return NULL;
}
}
// IThreadedTCPSocket implementation.
public:
virtual void Release()
{
delete this;
}
virtual CIPAddr GetRemoteAddr() const
{
return m_RemoteAddr;
}
virtual bool IsValid()
{
return !CheckErrorSignal();
}
virtual bool Send( const void *pData, int len )
{
const void *pChunks[1] = { pData };
return SendChunks( pChunks, &len, 1 );
}
virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks )
{
if ( CheckErrorSignal() )
return false;
return InternalSend( pChunks, pChunkLengths, nChunks, true );
}
// Initialization.
private:
CThreadedTCPSocket()
{
m_Socket = INVALID_SOCKET;
m_pHandler = NULL;
memset( &m_SendOverlapped, 0, sizeof( m_SendOverlapped ) );
memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) );
m_bWaitingForSendCompletion = false;
m_nBytesToReceive = -1;
m_bWaitingForSize = false;
m_bErrorSignal = false;
m_pRecvBuffer = NULL;
}
virtual ~CThreadedTCPSocket()
{
Term();
}
bool Init( SOCKET iSocket, CIPAddr remoteAddr, ITCPSocketHandler *pHandler )
{
m_Socket = iSocket;
m_RemoteAddr = remoteAddr;
m_pHandler = pHandler;
SetInitialSocketOptions();
// Create all the event objects we'll use to communicate.
m_hExitThreadsEvent.Init( true, false );
m_hSendCompletionEvent.Init( false, false );
m_hReadyToSendEvent.Init( false, false );
m_hRecvEvent.Init( false, false );
m_SendOverlapped.hEvent = m_hSendCompletionEvent.GetEventHandle();
m_RecvOverlapped.hEvent = m_hRecvEvent.GetEventHandle();
// Create our threads.
DWORD dwSendThreadID, dwRecvThreadID;
m_hSendThread = CreateThread( NULL, 0, &CThreadedTCPSocket::StaticSendThreadFn, this, CREATE_SUSPENDED, &dwSendThreadID );
m_hRecvThread = CreateThread( NULL, 0, &CThreadedTCPSocket::StaticRecvThreadFn, this, CREATE_SUSPENDED, &dwRecvThreadID );
if ( !m_hSendThread || !m_hRecvThread )
{
return false;
}
if ( g_bSetTCPSocketThreadPriorities )
{
SetThreadPriority( m_hSendThread, THREAD_PRIORITY_LOWEST );
SetThreadPriority( m_hRecvThread, THREAD_PRIORITY_LOWEST );
}
ThreadSetDebugName( (ThreadId_t)dwSendThreadID, "TCPSend" );
ThreadSetDebugName( (ThreadId_t)dwRecvThreadID, "TCPRecv" );
// Make sure to init the handler before the threads actually run, so it isn't handed data before initializing.
m_pHandler->Init( this );
ResumeThread( m_hSendThread );
ResumeThread( m_hRecvThread );
return true;
}
void Term()
{
// Signal our threads to exit.
m_hExitThreadsEvent.SetEvent();
if ( m_hSendThread )
{
WaitForSingleObject( m_hSendThread, INFINITE );
CloseHandle( m_hSendThread );
m_hSendThread = NULL;
}
if ( m_hRecvThread )
{
WaitForSingleObject( m_hRecvThread, INFINITE );
CloseHandle( m_hRecvThread );
m_hRecvThread = NULL;
}
m_hExitThreadsEvent.ResetEvent();
if ( m_Socket != INVALID_SOCKET )
{
closesocket( m_Socket );
m_Socket = INVALID_SOCKET;
}
}
// Set the initial socket options that we want.
void SetInitialSocketOptions()
{
// Set nodelay to improve latency.
BOOL val = TRUE;
setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) );
// Make it linger for 3 seconds when it exits.
LINGER linger;
linger.l_onoff = 1;
linger.l_linger = 3;
setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) );
}
// Send thread functionality.
private:
// This function copies off the payload and adds a SendChunk_t to the list of chunks to be sent.
// It also fires the ReadyToSend event so the thread will pick it up.
bool InternalSend( void const * const *pChunks, const int *pChunkLengths, int nChunks, bool bPrependLength )
{
int totalLength = 0;
for ( int i=0; i < nChunks; i++ )
totalLength += pChunkLengths[i];
if ( bPrependLength )
{
if ( totalLength == 0 )
return true;
totalLength += 4;
}
// Copy all the data into a SendData_t.
SendData_t *pSendData = (SendData_t*)malloc( sizeof( SendData_t ) - 1 + totalLength );
pSendData->m_Len = totalLength;
char *pOut = pSendData->m_Payload;
if ( bPrependLength )
{
*((int*)pOut) = totalLength - 4; // The length we prepend is the size of the data, not data size + integer for length.
pOut += 4;
}
for ( int i=0; i < nChunks; i++ )
{
memcpy( pOut, pChunks[i], pChunkLengths[i] );
pOut += pChunkLengths[i];
}
CCriticalSectionLock csLock( &m_SendCS );
csLock.Lock();
m_SendDatas.AddToTail( pSendData );
m_hReadyToSendEvent.SetEvent(); // Notify the thread that there is data to send.
csLock.Unlock();
return true;
}
void SendThread_HandleTimeout()
{
// Timeout.. send a keepalive.
// But only if we're not already sending something.
CCriticalSectionLock csLock( &m_SendCS );
csLock.Lock();
int count = m_SendDatas.Count();
csLock.Unlock();
if ( count == 0 )
{
void *pBuf[1] = { &g_KeepaliveSentinel };
int len[1] = { sizeof( g_KeepaliveSentinel ) };
InternalSend( pBuf, len, 1, false );
}
}
bool SendThread_HandleSendCompletionEvent()
{
Assert( m_bWaitingForSendCompletion );
m_bWaitingForSendCompletion = false;
// A send operation just completed. Now do the next one.
DWORD cbTransfer, flags;
if ( !WSAGetOverlappedResult( m_Socket, &m_SendOverlapped, &cbTransfer, TRUE, &flags ) )
{
HandleError( WSAGetLastError() );
return false;
}
if ( cbTransfer != m_nBytesToTransfer )
{
char str[512];
Q_snprintf( str, sizeof( str ), "Invalid # bytes transferred (%d) in send thread (should be %d)", cbTransfer, m_nBytesToTransfer );
HandleError( ITCPSocketHandler::SocketError, str );
return false;
}
// Remove the block we just sent.
CCriticalSectionLock csLock( &m_SendCS );
csLock.Lock();
SendData_t *pSendData = m_SendDatas[ m_SendDatas.Head() ];
free( pSendData );
m_SendDatas.Remove( m_SendDatas.Head() );
m_bWaitingForSendCompletion = false;
// Set our send event if there's anything else to send.
if ( m_SendDatas.Count() > 0 )
m_hReadyToSendEvent.SetEvent();
csLock.Unlock();
return true;
}
bool SendThread_HandleReadyToSendEvent()
{
// We've got at least one buffer that's ready to be sent.
// NOTE: don't send anything until our current send is completed.
CCriticalSectionLock csLock( &m_SendCS );
csLock.Lock();
Assert( !m_bWaitingForSendCompletion );
// Send it off!
SendData_t *pSendData = m_SendDatas[ m_SendDatas.Head() ];
WSABUF buf = { pSendData->m_Len, pSendData->m_Payload };
m_nBytesToTransfer = pSendData->m_Len;
m_bWaitingForSendCompletion = true;
csLock.Unlock();
DWORD dwNumBytesSent = 0;
DWORD ret = WSASend( m_Socket, &buf, 1, &dwNumBytesSent, 0, &m_SendOverlapped, NULL );
DWORD err = WSAGetLastError();
if ( ret == 0 || ( ret == SOCKET_ERROR && err == WSA_IO_PENDING ) )
{
// Either way, the operation completed successfully, and m_hSendCompletionEvent is now set.
return true;
}
else
{
HandleError( err );
return false;
}
return true;
}
DWORD SendThreadFn()
{
while ( 1 )
{
HANDLE handles[] =
{
m_hExitThreadsEvent.GetEventHandle(),
m_hSendCompletionEvent.GetEventHandle(),
m_hReadyToSendEvent.GetEventHandle()
};
int nHandles = ARRAYSIZE( handles );
// While waiting for send completion, don't handle "ready to send" events.
if ( m_bWaitingForSendCompletion )
--nHandles;
DWORD waitValue = WaitForMultipleObjects( nHandles, handles, FALSE, SEND_KEEPALIVE_INTERVAL );
switch ( waitValue )
{
case WAIT_TIMEOUT:
{
if ( g_bHandleTimeouts )
{
// We haven't sent anything in a bit. Send out a keepalive.
SendThread_HandleTimeout();
}
}
break;
case WAIT_OBJECT_0:
{
// The main thread is signaling us to exit.
return 0;
}
case WAIT_OBJECT_0 + 1:
{
if ( !SendThread_HandleSendCompletionEvent() )
return 1;
}
break;
case WAIT_OBJECT_0 + 2:
{
if ( !SendThread_HandleReadyToSendEvent() )
return 1;
}
break;
case WAIT_FAILED:
{
// Uh oh. We're dead. Cleanup and signal an error.
HandleError( GetLastError() );
return 1;
}
default:
{
char str[512];
Q_snprintf( str, sizeof( str ), "Unknown return value (%lu) from WaitForMultipleObjects", waitValue );
HandleError( ITCPSocketHandler::SocketError, str );
return 0;
}
}
}
return 0;
}
static DWORD WINAPI StaticSendThreadFn( LPVOID pParameter )
{
return ((CThreadedTCPSocket*)pParameter)->SendThreadFn();
}
// Receive thread functionality.
private:
bool RecvThread_WaitToReceiveSize()
{
return RecvThread_InternalRecv( &m_NextPacketLen, sizeof( m_NextPacketLen ), false, true );
}
bool RecvThread_InternalHandleRecvCompletion( DWORD dwTransfer )
{
int cbTransfer = (int)dwTransfer;
int nBytesWanted = m_nBytesToReceive - m_nBytesReceivedSoFar;
if ( cbTransfer > nBytesWanted )
{
char str[512];
Q_snprintf( str, sizeof( str ), "Invalid # bytes received (%d) in recv thread (should be %d)", cbTransfer, m_nBytesToReceive );
HandleError( ITCPSocketHandler::SocketError, str );
return false;
}
else if ( cbTransfer < nBytesWanted )
{
// We have to reissue the receive command because it didn't receive all the data.
m_nBytesReceivedSoFar += cbTransfer;
char *pDest = (char*)&m_NextPacketLen;
if ( !m_bWaitingForSize )
{
Assert( m_pRecvBuffer );
pDest = m_pRecvBuffer->m_Data;
}
return RecvThread_InternalRecv( &pDest[m_nBytesReceivedSoFar], m_nBytesToReceive - m_nBytesReceivedSoFar, true );
}
if ( m_bWaitingForSize )
{
// If we were waiting for size, now wait for the data.
if ( m_NextPacketLen == KEEPALIVE_SENTINEL )
{
// Ok, it was just a keepalive. Wait for size again.
return RecvThread_WaitToReceiveSize();
}
else
{
if ( m_NextPacketLen < 1 || m_NextPacketLen > 1024*1024*75 )
{
char str[512];
Q_snprintf( str, sizeof( str ), "Invalid packet size in RecvThread (size = %d)", m_NextPacketLen );
HandleError( ITCPSocketHandler::SocketError, str );
return false;
}
else
{
Assert( !m_pRecvBuffer );
m_pRecvBuffer = (CTCPPacket*)malloc( sizeof( CTCPPacket ) - 1 + m_NextPacketLen );
m_pRecvBuffer->m_UserData = 0;
m_pRecvBuffer->m_Len = m_NextPacketLen;
return RecvThread_InternalRecv( m_pRecvBuffer->m_Data, m_pRecvBuffer->m_Len, false, false );
}
}
}
else
{
// Got a packet! Give it to the app.
m_pHandler->OnPacketReceived( m_pRecvBuffer );
m_pRecvBuffer = NULL;
return RecvThread_WaitToReceiveSize();
}
}
bool RecvThread_HandleRecvCompletionEvent()
{
// A send operation just completed. Now do the next one.
DWORD cbTransfer, flags;
if ( !WSAGetOverlappedResult( m_Socket, &m_RecvOverlapped, &cbTransfer, TRUE, &flags ) )
{
HandleError( WSAGetLastError() );
return false;
}
return RecvThread_InternalHandleRecvCompletion( cbTransfer );
}
bool RecvThread_InternalRecv( void *pDest, int destSize, bool bContinuation, bool bWaitingForSize = false )
{
WSABUF buf = { destSize, (char*)pDest };
if ( !bContinuation )
{
// If this is not a continuation of whatever we were receiving before, then
m_bWaitingForSize = bWaitingForSize;
m_nBytesToReceive = destSize;
m_nBytesReceivedSoFar = 0;
}
DWORD dwFlags = 0;
DWORD nBytesReceived = 0;
DWORD ret = WSARecv( m_Socket, &buf, 1, &nBytesReceived, &dwFlags, &m_RecvOverlapped, NULL );
DWORD dwLastError = WSAGetLastError();
if ( ret == 0 || ( ret == SOCKET_ERROR && dwLastError == WSA_IO_PENDING ) )
{
// Note: m_hRecvEvent is in a signaled state, so the RecvThread will pick up the results next time around.
return true;
}
else
{
HandleError( dwLastError );
return false;
}
}
DWORD RecvThreadFn()
{
// Start us off by setting up to receive the first packet size.
if ( !RecvThread_WaitToReceiveSize() )
return 1;
HANDLE handles[] =
{
m_hExitThreadsEvent.GetEventHandle(),
m_hRecvEvent.GetEventHandle()
};
while ( 1 )
{
DWORD waitValue = WaitForMultipleObjects( ARRAYSIZE( handles ), handles, FALSE, KEEPALIVE_TIMEOUT );
switch ( waitValue )
{
case WAIT_TIMEOUT:
{
if ( g_bHandleTimeouts )
{
HandleError( ITCPSocketHandler::ConnectionTimedOut, "Connection timed out" );
return 1;
}
}
break;
case WAIT_OBJECT_0:
{
// We're being told to exit.
return 0;
}
case WAIT_OBJECT_0 + 1:
{
// Just finished receiving something.
if ( !RecvThread_HandleRecvCompletionEvent() )
return 1;
}
break;
case WAIT_FAILED:
{
// Uh oh. We're dead. Cleanup and signal an error.
HandleError( GetLastError() );
return 1;
}
default:
{
char str[512];
Q_snprintf( str, sizeof( str ), "Unknown return value (%lu) from WaitForMultipleObjects", waitValue );
HandleError( ITCPSocketHandler::SocketError, str );
return 1;
}
}
}
return 0;
}
static DWORD WINAPI StaticRecvThreadFn( LPVOID pParameter )
{
return ((CThreadedTCPSocket*)pParameter)->RecvThreadFn();
}
// Error handling.
private:
// This checks to see if either thread has signaled an error. If so, it shuts down the socket and returns true.
bool CheckErrorSignal()
{
return m_bErrorSignal;
}
// This is called from any of the threads and signals that something went awry. It shuts down the object
// and makes it return false from all of its functions.
void HandleError( DWORD errorValue )
{
char *lpMsgBuf;
FormatMessage(
FORMAT_MESSAGE_ALLOCATE_BUFFER |
FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL,
GetLastError(),
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language
(char*)&lpMsgBuf,
0,
NULL
);
// Windows likes to stick a carriage return in there and we don't want it so get rid of it.
int len = strlen( lpMsgBuf );
while ( len > 0 && ( lpMsgBuf[len-1] == '\n' || lpMsgBuf[len-1] == '\r' ) )
{
--len;
lpMsgBuf[len] = 0;
}
HandleError( ITCPSocketHandler::SocketError, lpMsgBuf );
LocalFree( lpMsgBuf );
}
void HandleError( int errorCode, const char *pErrorString )
{
//Assert( false );
// Tell the app.
m_pHandler->OnError( errorCode, pErrorString );
// Tell the threads to exit.
m_hExitThreadsEvent.SetEvent();
// Notify the main thread so it can call Term() when it gets a chance.
m_bErrorSignal = true;
}
private:
// Data for the send thread.
typedef struct
{
//WSAOVERLAPPED m_Overlapped;
int m_Len;
char m_Payload[1];
} SendData_t;
HANDLE m_hSendThread;
WSAOVERLAPPED m_SendOverlapped;
CEvent m_hReadyToSendEvent;
CEvent m_hSendCompletionEvent;
CCriticalSection m_SendCS;
DWORD m_nBytesToTransfer;
bool m_bWaitingForSendCompletion;
CUtlLinkedList<SendData_t*, int> m_SendDatas; // Added to the tail, popped off the head for sending.
// Data for the recv thread.
HANDLE m_hRecvThread;
int m_nBytesToReceive; // This stores how many bytes we want to receive for the next packet.
int m_nBytesReceivedSoFar; // This stores how many bytes we've received so far.
bool m_bWaitingForSize; // This tells if we're trying to receive the next packet length or the next packet's data.
int m_NextPacketLen; // Data is received INTO here before it receives each packet. This
// holds the length of each incoming packet.
WSAOVERLAPPED m_RecvOverlapped;
CEvent m_hRecvEvent;
CTCPPacket *m_pRecvBuffer; // This is allocated for each packet we're receiving and given to the
// app when the packet is done being received.
volatile bool m_bErrorSignal;
CEvent m_hExitThreadsEvent;
ITCPSocketHandler *m_pHandler;
SOCKET m_Socket;
CIPAddr m_RemoteAddr;
};
// ------------------------------------------------------------------------------------------------ //
// CTCPConnectSocket_Listener
// ------------------------------------------------------------------------------------------------ //
class CTCPConnectSocket_Listener : public ITCPConnectSocket
{
public:
CTCPConnectSocket_Listener()
{
m_Socket = INVALID_SOCKET;
}
virtual ~CTCPConnectSocket_Listener()
{
if ( m_Socket != INVALID_SOCKET )
{
closesocket( m_Socket );
}
}
// The main function to create one of these suckers.
static ITCPConnectSocket* Create(
IHandlerCreator *pHandlerCreator,
const unsigned short port,
int nQueueLength
)
{
CTCPConnectSocket_Listener *pRet = new CTCPConnectSocket_Listener;
if ( !pRet )
return NULL;
if ( nQueueLength < 0 )
{
Error( "CTCPConnectSocket_Listener::Create - SOMAXCONN not allowed - causes some XP SP2 systems to stop receiving any network data (systemwide)." );
}
// Bind it to a socket and start listening.
CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
pRet->m_Socket = TCPBind( &addr );
if ( pRet->m_Socket == INVALID_SOCKET ||
listen( pRet->m_Socket, nQueueLength == -1 ? SOMAXCONN : nQueueLength ) != 0 )
{
pRet->Release();
return false;
}
pRet->m_pHandler = pHandlerCreator;
return pRet;
}
// ITCPConnectSocket implementation.
public:
virtual void Release()
{
delete this;
}
virtual bool Update( IThreadedTCPSocket **pSocket, unsigned long milliseconds )
{
*pSocket = NULL;
if ( m_Socket == INVALID_SOCKET )
return false;
// We're still ok.. just wait until the socket becomes writable (is connected) or we timeout.
fd_set readSet;
readSet.fd_count = 1;
readSet.fd_array[0] = m_Socket;
TIMEVAL timeVal = {0, milliseconds*1000};
// Wait until it connects.
int status = select( 0, &readSet, NULL, NULL, &timeVal );
if ( status > 0 )
{
sockaddr_in addr;
int addrSize = sizeof( addr );
// Now accept the final connection.
SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize );
if ( newSock == INVALID_SOCKET )
{
Assert( false );
return true;
}
else
{
CIPAddr connectedAddr;
SockAddrToIPAddr( &addr, &connectedAddr );
IThreadedTCPSocket *pRet = CThreadedTCPSocket::Create( newSock, connectedAddr, m_pHandler->CreateNewHandler() );
if ( !pRet )
{
Assert( false );
closesocket( m_Socket );
m_Socket = INVALID_SOCKET;
return false;
}
*pSocket = pRet;
return true;
}
}
else if ( status == SOCKET_ERROR )
{
closesocket( m_Socket );
m_Socket = INVALID_SOCKET;
return false;
}
else
{
return true;
}
}
private:
SOCKET m_Socket;
IHandlerCreator *m_pHandler;
};
ITCPConnectSocket* ThreadedTCP_CreateListener(
IHandlerCreator *pHandlerCreator,
const unsigned short port,
int nQueueLength
)
{
return CTCPConnectSocket_Listener::Create( pHandlerCreator, port, nQueueLength );
}
// ------------------------------------------------------------------------------------------------ //
// CTCPConnectSocket_Connector
// ------------------------------------------------------------------------------------------------ //
class CTCPConnectSocket_Connector : public ITCPConnectSocket
{
public:
CTCPConnectSocket_Connector()
{
m_bConnected = false;
m_Socket = INVALID_SOCKET;
m_bError = false;
}
virtual ~CTCPConnectSocket_Connector()
{
if ( m_Socket != INVALID_SOCKET )
{
closesocket( m_Socket );
}
}
static ITCPConnectSocket* Create(
const CIPAddr &connectAddr,
const CIPAddr &localAddr,
IHandlerCreator *pHandlerCreator
)
{
CTCPConnectSocket_Connector *pRet = new CTCPConnectSocket_Connector;
pRet->m_Socket = TCPBind( &localAddr );
if ( pRet->m_Socket == INVALID_SOCKET )
{
pRet->Release();
return NULL;
}
sockaddr_in addr;
IPAddrToSockAddr( &connectAddr, &addr );
// We don't want the connect() call to block.
DWORD val = 1;
int status = ioctlsocket( pRet->m_Socket, FIONBIO, &val );
if ( status != 0 )
{
Assert( false );
pRet->Release();
return NULL;
}
pRet->m_RemoteAddr = connectAddr;
pRet->m_pHandlerCreator = pHandlerCreator;
int ret = connect( pRet->m_Socket, (struct sockaddr*)&addr, sizeof( addr ) );
if ( ret == 0 )
{
pRet->m_bConnected = true;
return pRet;
}
else if ( ret == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK )
{
return pRet;
}
else
{
Assert( false );
pRet->Release();
return NULL;
}
}
// ITCPConnectSocket implementation.
public:
virtual void Release()
{
delete this;
}
virtual bool Update( IThreadedTCPSocket **pSocket, unsigned long milliseconds )
{
*pSocket = NULL;
// If we got an error previously, keep returning false.
if ( m_bError )
return false;
// If this condition holds, then we already returned a valid socket and we're just waiting to be released.
if ( m_Socket == INVALID_SOCKET )
return true;
// Ok, see if we're connected now.
if ( !m_bConnected )
{
TIMEVAL timeVal = { 0, milliseconds*1000 };
fd_set writeSet;
writeSet.fd_count = 1;
writeSet.fd_array[0] = m_Socket;
int ret = select( 0, NULL, &writeSet, NULL, &timeVal );
if ( ret > 0 )
{
m_bConnected = true;
}
else if ( ret == SOCKET_ERROR )
{
return EnterErrorMode();
}
}
if ( m_bConnected )
{
// Ok, return a connected socket for them.
// Make our socket blocking again.
DWORD val = 0;
int status = ioctlsocket( m_Socket, FIONBIO, &val );
if ( status != 0 )
{
Assert( false );
m_bError = true;
closesocket( m_Socket );
m_Socket = INVALID_SOCKET;
return false;
}
IThreadedTCPSocket *pRet = CThreadedTCPSocket::Create( m_Socket, m_RemoteAddr, m_pHandlerCreator->CreateNewHandler() );
if ( pRet )
{
m_Socket = INVALID_SOCKET;
*pSocket = pRet;
return true;
}
else
{
return EnterErrorMode();
}
}
else
{
// Still waiting..
return true;
}
}
// Shutdown the socket and start returning false from Update().
bool EnterErrorMode()
{
Assert( false );
m_bError = true;
closesocket( m_Socket );
m_Socket = INVALID_SOCKET;
return false;
}
private:
bool m_bError;
bool m_bConnected;
SOCKET m_Socket;
CIPAddr m_RemoteAddr;
IHandlerCreator *m_pHandlerCreator;
};
ITCPConnectSocket* ThreadedTCP_CreateConnector(
const CIPAddr &addr,
const CIPAddr &localAddr,
IHandlerCreator *pHandlerCreator
)
{
return CTCPConnectSocket_Connector::Create( addr, localAddr, pHandlerCreator );
}
void ThreadedTCP_EnableTimeouts( bool bEnable )
{
g_bHandleTimeouts = bEnable;
}
void ThreadedTCP_SetTCPSocketThreadPriorities( bool bSetTCPSocketThreadPriorities )
{
g_bSetTCPSocketThreadPriorities = bSetTCPSocketThreadPriorities;
}