// traps.hpp

#pragma once

namespace traps{

// Finds the import address table (IAT) slot through which the loaded module
// hInstance calls the function strFunc that it imports by name from the
// module strExporter.
//
// hInstance   - base address of the importing module; 0 yields 0.
// strExporter - exporting module name as recorded in the import table;
//               compared case-insensitively.
// strFunc     - imported function name; compared case-insensitively.
//
// Returns a pointer to the matching IAT slot, or 0 when hInstance is 0, the
// module has no import directory, or no matching by-name import exists.
// Imports by ordinal carry no name and are therefore never matched.
//
// NOTE(review): the slot is addressed as DWORD, which assumes a 32-bit
// image - confirm before building for any other target.
static inline DWORD *GetThunkAddres(HINSTANCE hInstance, LPCSTR strExporter, LPCSTR strFunc)
{
	if (!hInstance)
		return 0;

	BYTE *image=(BYTE*)hInstance;

	// The optional header follows the 4-byte PE signature and the file header.
	PIMAGE_OPTIONAL_HEADER ioh=(PIMAGE_OPTIONAL_HEADER)(image
		+PIMAGE_DOS_HEADER(image)->e_lfanew+4+sizeof(IMAGE_FILE_HEADER));

	// An RVA of 0 means there is no import directory; without this check the
	// DOS header at the image base would be walked as if it were descriptors.
	DWORD importRva=ioh->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
	if (!importRva)
		return 0;

	// Look for the import descriptor(s) of strExporter.
	for (PIMAGE_IMPORT_DESCRIPTOR iid=(PIMAGE_IMPORT_DESCRIPTOR)(image+importRva);iid->Name;iid++)
	{
		if (strcmpi((LPCSTR)(image+iid->Name),strExporter)!=0)
			continue;

		// Without an import lookup table the names are not available
		// (OriginalFirstThunk==0 would make the loop below read the image base).
		if (!iid->OriginalFirstThunk)
			continue;

		// Walk the lookup table (names) and the IAT (addresses) in parallel,
		// looking for strFunc.
		PIMAGE_THUNK_DATA isd=(PIMAGE_THUNK_DATA)(image+iid->OriginalFirstThunk);
		DWORD *addr=(DWORD*)(image+iid->FirstThunk);

		// The entry is read through u1.Ordinal: it shares storage with
		// u1.AddressOfData and is a plain DWORD in every SDK version, so the
		// code does not depend on AddressOfData being declared as a pointer.
		for (;isd->u1.Ordinal;isd++,addr++)
		{
			// Entries with the ordinal flag set hold an ordinal, not the RVA
			// of a name record; treating them as an RVA yields a wild pointer.
			if (IMAGE_SNAP_BY_ORDINAL(isd->u1.Ordinal))
				continue;

			PIMAGE_IMPORT_BY_NAME byName=(PIMAGE_IMPORT_BY_NAME)(image+isd->u1.Ordinal);
			if (strcmpi((LPCSTR)byName->Name,strFunc)==0)
				return addr;
		}
	}
	return 0;
}

// ***

// Scope guard that toggles a trap for the lifetime of the guard object.
//
// Construction calls Swap() on the referenced trap and destruction calls
// Swap() again, so the trap returns to the state it had before the guard
// was created. The trap may be supplied by reference or by pointer; a
// pointer is dereferenced unconditionally and must not be null.
//
// CTrap must provide a Swap() member accessible to this class.
template <class CTrap>
	class CTrapSwapper
{
public:
	CTrapSwapper(CTrap &aTrap)
		: Trap(aTrap)
	{
		Toggle();
	}

	CTrapSwapper(CTrap *pTrap)
		: Trap(*pTrap)
	{
		Toggle();
	}

	~CTrapSwapper()
	{
		Toggle();
	}

protected:
	// The trap being toggled; it is referenced, not owned.
	CTrap &Trap;

private:
	// Flips the referenced trap between its two states.
	void Toggle()
	{
		Trap.Swap();
	}
};

// Generic trap: a value patched into memory and put back on destruction

// Patches one TData-sized value in memory and restores it on destruction.
//
// The public constructor makes the target writable, stores the new value at
// the target address and keeps the previous contents in nOldData. The
// destructor swaps the two back. A trap whose address is 0 does nothing.
//
// NOTE(review): no copy operations are declared, so the class is copyable;
// destroying both the original and a copy would swap the target twice.
// Avoid copying trap objects.
template <class TData>
	class CTrap
{
protected:
	// Location being patched; 0 turns Swap() and Trap() into no-ops.
	TData* pAddr;
	// Before Trap(): the value to install. After Trap(): the value that was
	// at *pAddr. Every Swap() exchanges it with *pAddr again.
	TData nOldData;

	// Exchanges *pAddr and nOldData. The target must already be writable.
	void Swap()
	{
		if (pAddr)
		{
			TData Temp=*pAddr;
			*pAddr=nOldData;
			nOldData=Temp;
		}
	}

	// Makes the target writable and performs the first swap.
	// The page protection is intentionally not restored afterwards, because
	// later Swap() calls (destructor, CTrapSwapper) write to the target again.
	void Trap()
	{
		if (pAddr)
		{
			DWORD n;
			// Unprotect the whole patched object, not a fixed 4 bytes: a
			// larger TData may extend into the following page.
			if (!VirtualProtect(pAddr,sizeof(TData),PAGE_EXECUTE_READWRITE,&n))
			{
				// The target could not be made writable; disarm the trap
				// instead of writing to protected memory now and in ~CTrap().
				pAddr=0;
				return;
			}
			Swap();
		}
	}

	friend class CTrapSwapper<CTrap<TData> >;

	// For derived classes: sets the address only. The derived class has to
	// fill nOldData and then call Trap() itself.
	CTrap(LPVOID pvAddr):
		pAddr((TData*)pvAddr)
	{
	}

	// For derived classes: creates an unarmed trap. pAddr starts as 0 so the
	// destructor is harmless until the derived class sets pAddr and nOldData
	// and calls Trap().
	CTrap():
		pAddr(0)
	{
	}

public:
	// Installs nNewData at pvAddr immediately; pvAddr==0 creates an inert trap.
	CTrap(LPVOID pvAddr, TData nNewData):
		pAddr((TData*)pvAddr), nOldData(nNewData)
	{
		Trap();
	}

	// Swaps *pAddr and nOldData once more, which restores the original value
	// if the trap is currently installed.
	~CTrap()
	{
		Swap();
	}
};

// Trap that replaces a stored function address (import table slot or virtual method table slot)

// Trap over a single DWORD-sized address slot: the slot is overwritten with
// the address of a replacement function, and the base class puts the
// previous value back when the object is destroyed.
class CTrapAddr: public CTrap<DWORD>
{
public:
	// Hook an imported function: the slot is located with GetThunkAddres()
	// in the importing module, given either as a module handle ...
	CTrapAddr(HINSTANCE hImporter, LPCSTR strExporter, LPCSTR strFunc, LPVOID newFunc)
		: CTrap<DWORD>(GetThunkAddres(hImporter, strExporter, strFunc), (DWORD)newFunc)
	{
	}

	// ... or as a module name resolved through GetModuleHandle().
	CTrapAddr(LPCSTR strImporter, LPCSTR strExporter, LPCSTR strFunc, LPVOID newFunc)
		: CTrap<DWORD>(GetThunkAddres(GetModuleHandle(strImporter), strExporter, strFunc), (DWORD)newFunc)
	{
	}

	// Hook a virtual method: the first pointer-sized field of *pObject is
	// taken as the method table (VMT) and nOffset is the index of the
	// DWORD-sized slot inside that table.
	CTrapAddr(LPVOID pObject, int nOffset, LPVOID newFunc)
		: CTrap<DWORD>(*reinterpret_cast<DWORD**>(pObject) + nOffset, (DWORD)newFunc)
	{
	}

	// Returns the saved slot value converted to the type of the argument.
	// The argument is used only for type deduction; its value is ignored.
	template <typename PFunc>
		PFunc OldFunc(PFunc)
	{
		LPVOID pvSaved = (LPVOID)nOldData;
		return (PFunc)pvSaved;
	}
};

// Trap that overwrites the beginning of a function with a jump to a replacement

// Image of a one-byte opcode followed by a 32-bit operand.
// Packing is set to 1 so that no padding is inserted between the two fields.
#pragma pack(push,1)
struct COneByteAndOffset
{
	BYTE Command;	// opcode byte
	DWORD Offset;	// 32-bit operand that immediately follows the opcode
};
#pragma pack(pop)

// Builds the opcode-plus-offset patch that redirects pFunc to pNewFunc and
// installs it through the base class.
class CTrapJumpHelper: public CTrap<COneByteAndOffset>
{
protected:
	// pFunc    - address whose first bytes are overwritten.
	// pNewFunc - address the written jump leads to.
	CTrapJumpHelper(LPVOID pFunc,LPVOID pNewFunc)
		: CTrap<COneByteAndOffset>(pFunc)
	{
		const DWORD dwSource=(DWORD)pFunc;
		const DWORD dwTarget=(DWORD)pNewFunc;

		// 0xE9 takes a displacement counted from the end of the
		// instruction, hence the extra 5 bytes (opcode + 32-bit operand).
		nOldData.Command=0xE9;
		nOldData.Offset=dwTarget-dwSource-5;

		Trap();
	}
};

// Trap that overwrites the beginning of a function with a jump to newFunc.
// The function to patch can be given directly, looked up as an export of a
// module, or taken from the import table of a module.
class CTrapJump: public CTrapJumpHelper
{
protected:
	// Returns the value currently stored in the import slot that hImporter
	// uses for strExporter!strFunc (the address the import resolves to),
	// or 0 when GetThunkAddres() does not find the slot.
	static LPVOID GetAddr(HINSTANCE hImporter, LPCSTR strExporter, LPCSTR strFunc)
	{
		DWORD *Addr=GetThunkAddres(hImporter, strExporter, strFunc);
		return Addr ? (LPVOID)*Addr : 0;
	}

	friend class CTrapSwapper<CTrapJump>;

public:
	// Patch the code at pFunc.
	CTrapJump(LPVOID pFunc, LPVOID newFunc):
		CTrapJumpHelper(pFunc,newFunc) {}

	// Patch a function exported by a module, located with GetProcAddress().
	// The FARPROC result is cast explicitly because standard C++ has no
	// implicit conversion from a function pointer to LPVOID. When the lookup
	// fails the address is 0 and the base class leaves the trap inert.

	CTrapJump(HINSTANCE hModule, LPCSTR strFuncName, LPVOID newFunc):
		CTrapJumpHelper(reinterpret_cast<LPVOID>(GetProcAddress(hModule,strFuncName)),newFunc) {}
	CTrapJump(LPCSTR strModule, LPCSTR strFuncName, LPVOID newFunc):
		CTrapJumpHelper(reinterpret_cast<LPVOID>(GetProcAddress(GetModuleHandle(strModule),strFuncName)),newFunc) {}

	// Patch the function that an import of a module currently resolves to
	// (see GetAddr() above).

	CTrapJump(HINSTANCE hImporter, LPCSTR strExporter, LPCSTR strFunc, LPVOID newFunc):
		CTrapJumpHelper(GetAddr(hImporter,strExporter,strFunc),newFunc) {}
	CTrapJump(LPCSTR strImporter, LPCSTR strExporter, LPCSTR strFunc, LPVOID newFunc):
		CTrapJumpHelper(GetAddr(GetModuleHandle(strImporter),strExporter,strFunc),newFunc) {}
};

} // namespace traps
