#ifndef __ATLTHRDX_H__
#define __ATLTHRDX_H__

#pragma once

/////////////////////////////////////////////////////////////////////////////
// Windows Vista Thread Pool
//
// Written by Bjarke Viksoe (bjarke@viksoe.dk)
// Copyright (c) 2008 Bjarke Viksoe.
//
// This code may be used in compiled form in any way you desire. This
// source file may be redistributed by any means PROVIDING it is 
// not sold for profit without the authors written consent, and 
// providing that this notice and the authors name is included. 
//
// This file is provided "as is" with no expressed or implied warranty.
// The author accepts no liability if it causes any damage to you or your
// computer whatsoever. It's free, so don't hassle me about it.
//
// Beware of bugs.
//

#if _WIN32_WINNT >= 0x0600


// Non-template base shared by all task wrappers in this file.
// The task templates register a CThreadPoolTaskBase* as the callback context
// and call InitTaskBase() at the start of every callback, so the helpers below
// operate on the callback instance of the callback that most recently started.
class CThreadPoolTaskBase
{
public:
   // Instance of the most recent callback (set by InitTaskBase()).
   // NOTE: it is never reset when the callback returns, so the helpers that use
   // it are only meaningful when called from inside the task's Run().
   PTP_CALLBACK_INSTANCE m_pCallbackInstance;
   // Event handed to SetEventWhenCallbackReturns() in every callback, or NULL.
   HANDLE m_hReleaseEvent;

   CThreadPoolTaskBase() : m_pCallbackInstance(NULL), m_hReleaseEvent(NULL)
   {
   }

#ifndef _ATL_THREADPOOL_NO_AUTODELETE

   // Needed by CThreadpoolAutoDeleteCleanupGroup for "delete" propagation
   virtual ~CThreadPoolTaskBase()
   {
   }

#endif // _ATL_THREADPOOL_NO_AUTODELETE

   // Called by the task templates at the start of each callback. Remembers the
   // instance and re-registers m_hReleaseEvent (if any) for this callback.
   void InitTaskBase(PTP_CALLBACK_INSTANCE pCallbackInstance)
   {
      m_pCallbackInstance = pCallbackInstance;
      if( m_hReleaseEvent != NULL ) ::SetEventWhenCallbackReturns(m_pCallbackInstance, m_hReleaseEvent);
   }

   // Calls CallbackMayRunLong() for the current callback and returns its result.
   BOOL IsTaskLongRunning() const
   {
      _ASSERTE(m_pCallbackInstance);
      return ::CallbackMayRunLong(m_pCallbackInstance);
   }

   // Calls DisassociateCurrentThreadFromCallback() for the current callback.
   void DisassociateCurrentThreadFromTask()
   {
      _ASSERTE(m_pCallbackInstance);
      ::DisassociateCurrentThreadFromCallback(m_pCallbackInstance);
   }

   // Stores hEvent so that InitTaskBase() registers it for every future callback,
   // and registers it immediately when a callback instance is already known.
   // NOTE(review): calling this outside Run() after a callback has finished uses
   // the stale m_pCallbackInstance - call it before submitting or from Run().
   void SetEventWhenTaskReturns(HANDLE hEvent)
   {
      _ASSERTE(hEvent!=INVALID_HANDLE_VALUE);
      m_hReleaseEvent = hEvent;
      if( m_pCallbackInstance != NULL ) ::SetEventWhenCallbackReturns(m_pCallbackInstance, hEvent);
   }

   // Releases hSemaphore cRel times when the current callback returns.
   void ReleaseSemaphoreWhenTaskReturns(HANDLE hSemaphore, DWORD cRel = 1)
   {
      _ASSERTE(m_pCallbackInstance);
      _ASSERTE(hSemaphore!=INVALID_HANDLE_VALUE);
      ::ReleaseSemaphoreWhenCallbackReturns(m_pCallbackInstance, hSemaphore, cRel);
   }

   // Leaves *pCS when the current callback returns.
   void LeaveCriticalSectionWhenTaskReturns(PCRITICAL_SECTION pCS)
   {
      _ASSERTE(m_pCallbackInstance);
      _ASSERTE(pCS);
      ::LeaveCriticalSectionWhenCallbackReturns(m_pCallbackInstance, pCS);
   }

   // Releases hMutex when the current callback returns. Named consistently with
   // the other "...WhenTaskReturns" helpers.
   void ReleaseMutexWhenTaskReturns(HANDLE hMutex)
   {
      ReleaseMutexWhenCallbackReturns(hMutex);
   }

   // Original name of ReleaseMutexWhenTaskReturns(); kept for existing callers.
   void ReleaseMutexWhenCallbackReturns(HANDLE hMutex)
   {
      _ASSERTE(m_pCallbackInstance);
      _ASSERTE(hMutex!=INVALID_HANDLE_VALUE);
      ::ReleaseMutexWhenCallbackReturns(m_pCallbackInstance, hMutex);
   }

   // Calls FreeLibraryWhenCallbackReturns() on hModule for the current callback.
   void FreeLibraryWhenTaskReturns(HMODULE hModule)
   {
      _ASSERTE(m_pCallbackInstance);
      ::FreeLibraryWhenCallbackReturns(m_pCallbackInstance, hModule);
   }
};


// Wraps a thread pool work object (PTP_WORK).
// T must derive from CThreadpoolWorkerTask<T> and provide "void Run()", which
// TaskCallback invokes for every submission.
template< typename T >
class CThreadpoolWorkerTask : public CThreadPoolTaskBase
{
public:
   PTP_WORK m_pWorker;     // Work object; NULL until the first successful AddToThreadpool()
   bool m_bCloseWorker;    // true when this object (not a cleanup group) must close m_pWorker

   CThreadpoolWorkerTask() : m_pWorker(NULL), m_bCloseWorker(false)
   {
   }

   ~CThreadpoolWorkerTask()
   {
      CloseTask();      
   }

   // Creates the work object on first use, then submits it.
   // pCleanupGroup is expected to be the cleanup group configured in pEnviron
   // (NULL if none); when non-NULL the group is responsible for closing the object.
   // Returns false if the work object could not be created.
   bool AddToThreadpool(PTP_CALLBACK_ENVIRON pEnviron, PTP_CLEANUP_GROUP pCleanupGroup) throw()
   {
      if( m_pWorker == NULL ) {
         m_pWorker = ::CreateThreadpoolWork(TaskCallback, static_cast<CThreadPoolTaskBase*>(this), pEnviron);
         if( m_pWorker == NULL ) return false;
         // Group membership is fixed when the object is created, so ownership is
         // decided here only; a later re-submit must not flip it (that would leak).
         m_bCloseWorker = (pCleanupGroup == NULL);
      }
      ::SubmitThreadpoolWork(m_pWorker);
      return true;
   }

   // Closes the work object if we own it, and forgets it in either case.
   void CloseTask() throw()
   {
      if( m_pWorker != NULL && m_bCloseWorker ) ::CloseThreadpoolWork(m_pWorker);
      m_pWorker = NULL;
      m_bCloseWorker = false;
   }

   // Waits for outstanding callbacks; bCancelPending is passed straight through.
   void WaitForTask(BOOL bCancelPending = FALSE) throw()
   {
      _ASSERTE(m_pWorker);
      ::WaitForThreadpoolWorkCallbacks(m_pWorker, bCancelPending);
   }

   operator PTP_WORK()
   {
      return m_pWorker;
   }

   // Thread pool entry point. pParam was registered as a CThreadPoolTaskBase*,
   // so it is converted back through that type before downcasting to T.
   static VOID CALLBACK TaskCallback(PTP_CALLBACK_INSTANCE pInstance, PVOID pParam, PTP_WORK /*pWork*/)
   {
      CThreadPoolTaskBase* pBase = reinterpret_cast<CThreadPoolTaskBase*>(pParam);
      T* pT = static_cast<T*>(pBase);
      pT->InitTaskBase(pInstance);
      pT->Run();
   }

   // Default no-op; T is expected to provide its own Run().
   void Run()
   {
      // TODO: Must override this!
   }
};


// Wraps a thread pool wait object (PTP_WAIT).
// T must derive from CThreadpoolWaitTask<T> and provide
// "void Run(TP_WAIT_RESULT)", which TaskCallback invokes with the wait result.
template< typename T >
class CThreadpoolWaitTask : public CThreadPoolTaskBase
{
public:
   PTP_WAIT m_pWorker;            // Wait object; NULL until the first successful AddToThreadpool()
   HANDLE m_hCallbackEvent;       // Handle to wait on (INVALID_HANDLE_VALUE until set)
   FILETIME m_ftCallbackTimeout;  // Timeout passed to SetThreadpoolWait when m_bUseFiletime
   bool m_bUseFiletime;           // false => wait is registered with a NULL timeout
   bool m_bCloseWorker;           // true when this object (not a cleanup group) must close m_pWorker

   // Initializers are listed in declaration order (the order they actually run).
   CThreadpoolWaitTask() : m_pWorker(NULL), m_hCallbackEvent(INVALID_HANDLE_VALUE), m_bUseFiletime(false), m_bCloseWorker(false)
   {
      m_ftCallbackTimeout.dwLowDateTime = 
      m_ftCallbackTimeout.dwHighDateTime = 0;
   }

   ~CThreadpoolWaitTask()
   {
      CloseTask();
   }

   // Creates the wait object on first use and (re)arms it with the stored
   // handle/timeout. pCleanupGroup is expected to be the group configured in
   // pEnviron (NULL if none). Returns false if the object could not be created.
   bool AddToThreadpool(PTP_CALLBACK_ENVIRON pEnviron, PTP_CLEANUP_GROUP pCleanupGroup) throw()
   {
      _ASSERTE(m_hCallbackEvent!=INVALID_HANDLE_VALUE);
      if( m_pWorker == NULL ) {
         m_pWorker = ::CreateThreadpoolWait(TaskCallback, static_cast<CThreadPoolTaskBase*>(this), pEnviron);
         if( m_pWorker == NULL ) return false;
         // Ownership is decided once, at creation; re-arming must not change it.
         m_bCloseWorker = (pCleanupGroup == NULL);
      }
      ::SetThreadpoolWait(m_pWorker, m_hCallbackEvent, m_bUseFiletime ? &m_ftCallbackTimeout : NULL);
      return true;
   }

   // Wait on hEvent with a NULL timeout; re-arms immediately if already added.
   void SetTaskWaitInfo(HANDLE hEvent) throw()
   {
      _ASSERTE(hEvent!=INVALID_HANDLE_VALUE);
      m_hCallbackEvent = hEvent;
      m_bUseFiletime = false;
      if( m_pWorker != NULL ) ::SetThreadpoolWait(m_pWorker, hEvent, NULL);
   }

   // Wait on hEvent with the FILETIME timeout *pftTimeout (copied).
   void SetTaskWaitInfo(HANDLE hEvent, PFILETIME pftTimeout) throw()
   {
      _ASSERTE(hEvent!=INVALID_HANDLE_VALUE);
      _ASSERTE(pftTimeout);
      m_hCallbackEvent = hEvent;
      m_bUseFiletime = true;
      m_ftCallbackTimeout = *pftTimeout;
      if( m_pWorker != NULL ) ::SetThreadpoolWait(m_pWorker, hEvent, pftTimeout);
   }

   // Wait on hEvent with a relative timeout in milliseconds.
   // INFINITE is mapped to a NULL timeout (as in the single-argument overload)
   // instead of being converted into a ~49.7 day relative timeout.
   void SetTaskWaitInfo(HANDLE hEvent, DWORD dwTimeoutMS) throw()
   {
      if( dwTimeoutMS == INFINITE ) {
         SetTaskWaitInfo(hEvent);
         return;
      }
      _ASSERTE(hEvent!=INVALID_HANDLE_VALUE);
      m_hCallbackEvent = hEvent;
      m_bUseFiletime = true;
      // Negative value in 100-ns units (ms * 10000) encodes a relative time.
      const ULONGLONG qwDue = static_cast<ULONGLONG>(-static_cast<LONGLONG>(dwTimeoutMS) * 10000);
      m_ftCallbackTimeout.dwLowDateTime  = static_cast<DWORD>(qwDue & 0xFFFFFFFF);
      m_ftCallbackTimeout.dwHighDateTime = static_cast<DWORD>(qwDue >> 32);
      if( m_pWorker != NULL ) ::SetThreadpoolWait(m_pWorker, m_hCallbackEvent, &m_ftCallbackTimeout);
   }

   // Closes the wait object if we own it, and forgets it in either case.
   void CloseTask() throw()
   {
      if( m_pWorker != NULL && m_bCloseWorker ) ::CloseThreadpoolWait(m_pWorker);
      m_pWorker = NULL;
      m_bCloseWorker = false;
   }

   // Waits for outstanding callbacks; bCancelPending is passed straight through.
   void WaitForTask(BOOL bCancelPending = FALSE) throw()
   {
      _ASSERTE(m_pWorker);
      ::WaitForThreadpoolWaitCallbacks(m_pWorker, bCancelPending);
   }

   operator PTP_WAIT()
   {
      return m_pWorker;
   }

   // Thread pool entry point; converts the context back via CThreadPoolTaskBase*.
   static VOID CALLBACK TaskCallback(PTP_CALLBACK_INSTANCE pInstance, PVOID pParam, PTP_WAIT /*pWait*/, TP_WAIT_RESULT WaitResult)
   {
      CThreadPoolTaskBase* pBase = reinterpret_cast<CThreadPoolTaskBase*>(pParam);
      T* pT = static_cast<T*>(pBase);
      pT->InitTaskBase(pInstance);
      pT->Run(WaitResult);
   }

   // Default no-op; T is expected to provide its own Run(TP_WAIT_RESULT).
   void Run(TP_WAIT_RESULT /*WaitResult*/)
   {
      // TODO: Must override this!
   }
};


// Wraps a thread pool timer object (PTP_TIMER).
// T must derive from CThreadpoolTimerTask<T> and provide "void Run()", which
// TaskCallback invokes each time the timer callback fires.
template< typename T >
class CThreadpoolTimerTask : public CThreadPoolTaskBase
{
public:
   PTP_TIMER m_pWorker;            // Timer object; NULL until the first successful AddToThreadpool()
   DWORD m_dwTaskPeriod;           // Period argument for SetThreadpoolTimer()
   DWORD m_dwTaskWindowLength;     // Window-length argument for SetThreadpoolTimer()
   FILETIME m_ftCallbackTimeout;   // Due time for SetThreadpoolTimer() (zero until set)
   bool m_bCloseWorker;            // true when this object (not a cleanup group) must close m_pWorker

   // Initializers are listed in declaration order (the order they actually run).
   CThreadpoolTimerTask() : m_pWorker(NULL), m_dwTaskPeriod(0), m_dwTaskWindowLength(0), m_bCloseWorker(false)
   {
      m_ftCallbackTimeout.dwLowDateTime = 
      m_ftCallbackTimeout.dwHighDateTime = 0;
   }

   ~CThreadpoolTimerTask()
   {
      CloseTask();
   }

   // Creates the timer object on first use and sets it with the stored due
   // time/period/window. pCleanupGroup is expected to be the group configured
   // in pEnviron (NULL if none). Returns false if the object could not be created.
   bool AddToThreadpool(PTP_CALLBACK_ENVIRON pEnviron, PTP_CLEANUP_GROUP pCleanupGroup) throw()
   {
      if( m_pWorker == NULL ) {
         m_pWorker = ::CreateThreadpoolTimer(TaskCallback, static_cast<CThreadPoolTaskBase*>(this), pEnviron);
         if( m_pWorker == NULL ) return false;
         // Ownership is decided once, at creation; re-adding must not change it.
         m_bCloseWorker = (pCleanupGroup == NULL);
      }
      ::SetThreadpoolTimer(m_pWorker, &m_ftCallbackTimeout, m_dwTaskPeriod, m_dwTaskWindowLength);
      return true;
   }

   // Returns IsThreadpoolTimerSet() for the timer object.
   BOOL IsTaskTimerSet() const
   {
      _ASSERTE(m_pWorker);
      return ::IsThreadpoolTimerSet(m_pWorker);
   }

   // Stores a FILETIME due time plus period/window; applied immediately if added.
   void SetTaskTimer(const FILETIME ftTimeout, DWORD dwPeriod = 0, DWORD dwWindowLength = 0) throw()
   {
      m_ftCallbackTimeout = ftTimeout;
      m_dwTaskPeriod = dwPeriod;
      m_dwTaskWindowLength = dwWindowLength;      
      if( m_pWorker != NULL ) ::SetThreadpoolTimer(m_pWorker, &m_ftCallbackTimeout, m_dwTaskPeriod, m_dwTaskWindowLength);
   }

   // Stores a due time of dwTimeoutMS milliseconds from now plus period/window;
   // applied immediately if added.
   void SetTaskTimer(DWORD dwTimeoutMS, DWORD dwPeriod = 0, DWORD dwWindowLength = 0) throw()
   {
      // Negative value in 100-ns units (ms * 10000) encodes a relative time.
      const ULONGLONG qwDue = static_cast<ULONGLONG>(-static_cast<LONGLONG>(dwTimeoutMS) * 10000);
      m_ftCallbackTimeout.dwLowDateTime  = static_cast<DWORD>(qwDue & 0xFFFFFFFF);
      m_ftCallbackTimeout.dwHighDateTime = static_cast<DWORD>(qwDue >> 32);
      m_dwTaskPeriod = dwPeriod;
      m_dwTaskWindowLength = dwWindowLength;
      if( m_pWorker != NULL ) ::SetThreadpoolTimer(m_pWorker, &m_ftCallbackTimeout, m_dwTaskPeriod, m_dwTaskWindowLength);
   }

   // Calls SetThreadpoolTimer() with a NULL due time.
   void StopTaskTimer() throw()
   {
      _ASSERTE(m_pWorker);
      ::SetThreadpoolTimer(m_pWorker, NULL, 0, 0);
   }

   // Closes the timer object if we own it, and forgets it in either case.
   void CloseTask() throw()
   {
      if( m_pWorker != NULL && m_bCloseWorker ) ::CloseThreadpoolTimer(m_pWorker);
      m_pWorker = NULL;
      m_bCloseWorker = false;
   }

   // Waits for outstanding callbacks; bCancelPending is passed straight through.
   void WaitForTask(BOOL bCancelPending = FALSE) throw()
   {
      _ASSERTE(m_pWorker);
      ::WaitForThreadpoolTimerCallbacks(m_pWorker, bCancelPending);
   }

   operator PTP_TIMER()
   {
      return m_pWorker;
   }

   // Thread pool entry point. The context was registered as a
   // CThreadPoolTaskBase*, so it must be downcast from that type: casting the
   // void* straight to T* is wrong whenever the base subobject is not at
   // offset 0 within T (multiple inheritance, or a vfptr introduced by T).
   static VOID CALLBACK TaskCallback(PTP_CALLBACK_INSTANCE pInstance, PVOID pParam, PTP_TIMER /*pTimer*/)
   {
      CThreadPoolTaskBase* pBase = reinterpret_cast<CThreadPoolTaskBase*>(pParam);
      T* pT = static_cast<T*>(pBase);
      pT->InitTaskBase(pInstance);
      pT->Run();
   }

   // Default no-op; T is expected to provide its own Run().
   void Run()
   {
      // TODO: Must override this!
   }
};


// Wraps a thread pool I/O completion object (PTP_IO) bound to m_hFile.
// T must derive from CThreadpoolIoTask<T> and provide
// "void Run(PVOID, ULONG, ULONG_PTR)", which TaskCallback invokes per completion.
template< typename T >
class CThreadpoolIoTask : public CThreadPoolTaskBase
{
public:
   PTP_IO m_pWorker;       // I/O object; NULL until the first successful AddToThreadpool()
   HANDLE m_hFile;         // Handle bound to the I/O object (set via SetFile())
   bool m_bCloseWorker;    // true when this object (not a cleanup group) must close m_pWorker

   // Initializers are listed in declaration order (the order they actually run).
   CThreadpoolIoTask() : m_pWorker(NULL), m_hFile(INVALID_HANDLE_VALUE), m_bCloseWorker(false)
   {
   }

   ~CThreadpoolIoTask()
   {
      CloseTask();
   }

   // Creates the I/O object for m_hFile on first use, then calls
   // StartThreadpoolIo() once. pCleanupGroup is expected to be the group
   // configured in pEnviron (NULL if none). Returns false if creation failed.
   bool AddToThreadpool(PTP_CALLBACK_ENVIRON pEnviron, PTP_CLEANUP_GROUP pCleanupGroup) throw()
   {
      _ASSERTE(m_hFile!=INVALID_HANDLE_VALUE);
      if( m_pWorker == NULL ) {
         m_pWorker = ::CreateThreadpoolIo(m_hFile, TaskCallback, static_cast<CThreadPoolTaskBase*>(this), pEnviron);
         if( m_pWorker == NULL ) return false;
         // Ownership is decided once, at creation; later calls must not change it.
         m_bCloseWorker = (pCleanupGroup == NULL);
      }
      ::StartThreadpoolIo(m_pWorker);
      return true;
   }

   // Sets the handle to bind; must be called before AddToThreadpool().
   void SetFile(HANDLE hFile)
   {
      _ASSERTE(m_pWorker==NULL);
      _ASSERTE(hFile!=INVALID_HANDLE_VALUE);
      m_hFile = hFile;
   }

   // Calls CancelThreadpoolIo() on the I/O object.
   void CancelTask() throw()
   {
      _ASSERTE(m_pWorker);
      ::CancelThreadpoolIo(m_pWorker);
   }

   // Closes the I/O object if we own it, and forgets it in either case.
   void CloseTask() throw()
   {
      if( m_pWorker != NULL && m_bCloseWorker ) ::CloseThreadpoolIo(m_pWorker);
      m_pWorker = NULL;
      m_bCloseWorker = false;
   }

   // Waits for outstanding callbacks; bCancelPending is passed straight through.
   void WaitForTask(BOOL bCancelPending = FALSE) throw()
   {
      _ASSERTE(m_pWorker);
      ::WaitForThreadpoolIoCallbacks(m_pWorker, bCancelPending);
   }

   operator PTP_IO()
   {
      return m_pWorker;
   }

   // Thread pool entry point; converts the context back via CThreadPoolTaskBase*
   // and forwards the completion data to T::Run().
   static VOID CALLBACK TaskCallback(PTP_CALLBACK_INSTANCE pInstance,
                PVOID pParam,
                PVOID pOverlapped,
                ULONG uIoResult,
                ULONG_PTR uNumberOfBytesTransferred,
                PTP_IO /*Io*/)
   {
      CThreadPoolTaskBase* pBase = reinterpret_cast<CThreadPoolTaskBase*>(pParam);
      T* pT = static_cast<T*>(pBase);
      pT->InitTaskBase(pInstance);
      pT->Run(pOverlapped, uIoResult, uNumberOfBytesTransferred);
   }

   // Default no-op; T is expected to provide its own Run().
   void Run(PVOID /*pOverlapped*/, ULONG /*uIoResult*/, ULONG_PTR /*uNumberOfBytesTransferred*/)
   {
      // TODO: Must override this!
   }
};


// Owning wrapper around a PTP_CLEANUP_GROUP (CRTP: T is the derived class).
// T may provide "void OnCleanupCallback(CThreadPoolTaskBase*)"; CleanupCallback
// dispatches to it statically through the T* passed as the cleanup context.
// NOTE(review): no copy constructor/assignment is declared, so copies would
// close the same handle twice in their destructors - do not copy instances.
template< typename T >
class CThreadpoolCleanupGroupT
{
public:
   // Owned cleanup group; NULL when nothing is held.
   PTP_CLEANUP_GROUP m_pCleanupGroup;

   // Creates a new cleanup group (may fail: check IsNull()).
   CThreadpoolCleanupGroupT() : m_pCleanupGroup(::CreateThreadpoolCleanupGroup())
   {
   }

   // Takes ownership of an existing cleanup group (NULL is allowed).
   CThreadpoolCleanupGroupT(PTP_CLEANUP_GROUP pCleanupGroup) : m_pCleanupGroup(pCleanupGroup)
   {
   }

   ~CThreadpoolCleanupGroupT()
   {
      Close();
   }

   bool IsNull() const
   {
      return m_pCleanupGroup == NULL;
   }

   operator PTP_CLEANUP_GROUP()
   {
      return m_pCleanupGroup;
   }

   // Calls CloseThreadpoolCleanupGroup() on the owned group; no-op when empty.
   void Close()
   {
      if( m_pCleanupGroup != NULL ) {
         ::CloseThreadpoolCleanupGroup(m_pCleanupGroup);
         m_pCleanupGroup = NULL;
      }
   }

   // Closes the current group, then takes ownership of pCleanupGroup.
   void Attach(PTP_CLEANUP_GROUP pCleanupGroup)
   {
      Close();
      m_pCleanupGroup = pCleanupGroup;
   }

   // Releases ownership without closing and returns the group.
   PTP_CLEANUP_GROUP Detach()
   {
      PTP_CLEANUP_GROUP pReleased = m_pCleanupGroup;
      m_pCleanupGroup = NULL;
      return pReleased;
   }

   // Calls CloseThreadpoolCleanupGroupMembers() with this object (as T*) as the
   // cleanup context; bCancelWaiting is forwarded as the cancel flag.
   void WaitForGroup(BOOL bCancelWaiting = FALSE)
   {
      _ASSERTE(m_pCleanupGroup);
      ::CloseThreadpoolCleanupGroupMembers(m_pCleanupGroup, bCancelWaiting, static_cast<T*>(this));
   }

   // Same as WaitForGroup() with the cancel flag set to TRUE.
   void CancelGroup()
   {
      WaitForGroup(TRUE);
   }

   // Static trampoline usable as a PTP_CLEANUP_GROUP_CANCEL_CALLBACK: the
   // cleanup context is the T* passed above, the object context is a task.
   static VOID NTAPI CleanupCallback(PVOID pObjectContext, PVOID pCleanupContext)
   {
      T* pGroup = static_cast<T*>(pCleanupContext);
      pGroup->OnCleanupCallback(reinterpret_cast<CThreadPoolTaskBase*>(pObjectContext));
   }

   // Default handler: does nothing.
   void OnCleanupCallback(CThreadPoolTaskBase* /*pBase*/)
   {
   }  
};


// Plain cleanup group: inherits the no-op OnCleanupCallback() from the base.
class CThreadpoolCleanupGroup : public CThreadpoolCleanupGroupT<CThreadpoolCleanupGroup>
{ 
public:
   // Creates and owns a new cleanup group (check IsNull() for failure).
   CThreadpoolCleanupGroup() : CThreadpoolCleanupGroupT<CThreadpoolCleanupGroup>()
   {
   }

   // Takes ownership of an existing cleanup group (NULL is allowed).
   CThreadpoolCleanupGroup(PTP_CLEANUP_GROUP pCleanupGroup) : CThreadpoolCleanupGroupT<CThreadpoolCleanupGroup>(pCleanupGroup)
   {
   }
};


#ifndef _ATL_THREADPOOL_NO_AUTODELETE

// Cleanup group whose OnCleanupCallback() deletes the task object it receives.
// The deletion goes through CThreadPoolTaskBase's virtual destructor.
class CThreadpoolAutoDeleteCleanupGroup : public CThreadpoolCleanupGroupT<CThreadpoolAutoDeleteCleanupGroup>
{ 
public:
   // Creates and owns a new cleanup group (check IsNull() for failure).
   CThreadpoolAutoDeleteCleanupGroup() : CThreadpoolCleanupGroupT<CThreadpoolAutoDeleteCleanupGroup>()
   {
   }

   // Takes ownership of an existing cleanup group (NULL is allowed).
   CThreadpoolAutoDeleteCleanupGroup(PTP_CLEANUP_GROUP pCleanupGroup) : CThreadpoolCleanupGroupT<CThreadpoolAutoDeleteCleanupGroup>(pCleanupGroup)
   {
   }

   // Reached via CThreadpoolCleanupGroupT::CleanupCallback; frees the task.
   void OnCleanupCallback(CThreadPoolTaskBase* pTask)
   {
      delete pTask;
   }  
};

#endif // _ATL_THREADPOOL_NO_AUTODELETE


// Private thread pool plus the callback environment handed to every task.
// Typical use: Create(), optionally InitGroup(), then AddTaskToPool() per task.
class CThreadpool
{
public:
   PTP_POOL m_pPool;                            // Pool created by Create(); NULL otherwise
   TP_CALLBACK_ENVIRON m_CBEnviron;             // Environment passed to each task
   PTP_CLEANUP_GROUP m_pActiveGroup;            // Cleanup group currently set on m_CBEnviron (or NULL)
   CThreadpoolCleanupGroup m_MainCleanupGroup;  // Built-in group managed by InitGroup()

   // Initializers are listed in declaration order (the order they actually run).
   CThreadpool() : m_pPool(NULL), m_pActiveGroup(NULL), m_MainCleanupGroup(NULL)
   {
      // Keep the environment in a defined state until Create() re-initializes it.
      ::InitializeThreadpoolEnvironment(&m_CBEnviron);
   }

   ~CThreadpool()
   {
      Destroy();
   }

   // Destroys any previous pool, re-initializes the environment and creates a
   // new pool bound to it. Returns FALSE if CreateThreadpool() fails.
   BOOL Create()
   {
      Destroy();
      ::InitializeThreadpoolEnvironment(&m_CBEnviron);
      m_pPool = ::CreateThreadpool(NULL);
      if( m_pPool == NULL ) return FALSE;
      ::SetThreadpoolCallbackPool(&m_CBEnviron, m_pPool);
      return TRUE;
   }

   // Cancels and closes the built-in cleanup group (if any) and closes the pool.
   void Destroy() throw()
   {
      if( !m_MainCleanupGroup.IsNull() ) {
         m_MainCleanupGroup.CancelGroup(); 
         m_MainCleanupGroup.Close();
      }
      // Forget the active group: otherwise, after a later Create() (which resets
      // the environment), AddTaskToPool() would still report a stale group and
      // tasks would wrongly assume a group will close their objects.
      m_pActiveGroup = NULL;
      if( m_pPool != NULL ) ::CloseThreadpool(m_pPool);
      m_pPool = NULL;
   }

   bool IsNull() const
   {
      return (m_pPool == NULL);
   }

   operator PTP_POOL()
   {
      return m_pPool;
   }

   // Releases ownership of the pool without closing it.
   PTP_POOL Detach()
   {
      PTP_POOL hTemp = m_pPool;
      m_pPool = NULL;
      return hTemp;
   }

   BOOL SetMinimumThreadCount(DWORD dwMin) throw()
   {
      _ASSERTE(m_pPool);
      _ASSERTE(dwMin<0x80000000);
      return ::SetThreadpoolThreadMinimum(m_pPool, dwMin);
   }

   void SetMaximumThreadCount(DWORD dwMax) throw()
   {
      _ASSERTE(m_pPool);
      _ASSERTE(dwMax<0x80000000);
      ::SetThreadpoolThreadMaximum(m_pPool, dwMax);
   }

   // Maximum thread count = dwScale * number of processors.
   void SetMaximumThreadCountToCPUs(DWORD dwScale = 1) throw()
   {
      SetMaximumThreadCount(dwScale * GetNumberOfProcessors());
   }

   // Pins the pool to exactly one thread (max 1, then min 1).
   BOOL SetThreadPersistent() throw()
   {
      SetMaximumThreadCount(1);
      return SetMinimumThreadCount(1);
   }

   DWORD GetNumberOfProcessors() const
   {
      SYSTEM_INFO si = { 0 };
      ::GetSystemInfo(&si);
      return si.dwNumberOfProcessors;
   }

   void SetLibraryRef(HINSTANCE hModule)
   {
      _ASSERTE(m_pPool);
      ::SetThreadpoolCallbackLibrary(&m_CBEnviron, hModule);
   }

   void MarkTasksAsLongRunning()
   {
      _ASSERTE(m_pPool);
      ::SetThreadpoolCallbackRunsLong(&m_CBEnviron);
   }

   // Sets pCleanupGroup (and optional cancel callback) on the environment and
   // remembers it for AddTaskToPool(). Must follow Create(), which resets the
   // environment.
   void ActivateCleanupGroup(PTP_CLEANUP_GROUP pCleanupGroup, PTP_CLEANUP_GROUP_CANCEL_CALLBACK pCallback = NULL)
   {
      _ASSERTE(m_pPool);
      ::SetThreadpoolCallbackCleanupGroup(&m_CBEnviron, pCleanupGroup, pCallback);
      m_pActiveGroup = pCleanupGroup;
   }

   // Hands the environment and active cleanup group to the task's AddToThreadpool().
   template< typename TCallback >
   BOOL AddTaskToPool(TCallback* pTask)
   {
      return pTask->AddToThreadpool(&m_CBEnviron, m_pActiveGroup);
   }

   // Built-in cleanup group (for easy maintenance)

   // Creates the built-in cleanup group and activates it. Must follow Create():
   // Create() calls Destroy(), which would cancel and close this group.
   // Returns NULL if the group could not be created.
   PTP_CLEANUP_GROUP InitGroup()
   {
      _ASSERTE(m_pPool);
      _ASSERTE(m_MainCleanupGroup.IsNull());
      PTP_CLEANUP_GROUP pCleanupGroup = ::CreateThreadpoolCleanupGroup();
      if( pCleanupGroup == NULL ) return NULL;
      m_MainCleanupGroup.Attach(pCleanupGroup);
      ActivateCleanupGroup(pCleanupGroup);
      return pCleanupGroup;
   }

   // Closes the built-in group's members without cancelling pending callbacks.
   void WaitForGroup()
   {
      _ASSERTE(!m_MainCleanupGroup.IsNull());
      m_MainCleanupGroup.WaitForGroup();
   }

   // Closes the built-in group's members, cancelling pending callbacks.
   void CancelGroup()
   {
      _ASSERTE(!m_MainCleanupGroup.IsNull());
      m_MainCleanupGroup.CancelGroup();
   }
};


#endif // _WIN32_WINNT

#endif // __ATLTHRDX_H__
