// traps.hpp

#pragma once

namespace traps{

// Locates the import address table (IAT) slot through which the loaded module
// hInstance calls the function strFunc imported from the module strExporter.
//
//   hInstance   - base address of the importing module as mapped in memory
//                 (e.g. from GetModuleHandle); may be 0.
//   strExporter - exporter name as recorded in the import directory
//                 (e.g. "kernel32.dll"); compared case-insensitively.
//   strFunc     - imported function name; compared case-insensitively.
//
// Returns a pointer to the IAT slot (which holds the resolved function
// address), or 0 when:
//   - any argument is null or the headers are not a valid PE image;
//   - the image has no import directory;
//   - the exporter is not imported;
//   - the function is not imported by name.
// Imports by ordinal are skipped, because they carry no name to compare.
//
// NOTE: the DWORD* return type only matches the slot width in 32-bit images.
static inline DWORD *GetThunkAddres(HINSTANCE hInstance, LPCSTR strExporter, LPCSTR strFunc)
{
	if (!hInstance || !strExporter || !strFunc)
		return 0;

	BYTE *image=(BYTE*)hInstance;
	PIMAGE_DOS_HEADER dos=(PIMAGE_DOS_HEADER)image;
	if (dos->e_magic!=IMAGE_DOS_SIGNATURE)
		return 0;

	PIMAGE_NT_HEADERS nt=(PIMAGE_NT_HEADERS)(image+dos->e_lfanew);
	if (nt->Signature!=IMAGE_NT_SIGNATURE)
		return 0;

	// The import data directory may be absent altogether, or present with a zero RVA.
	// Treating a zero RVA as an offset would parse the image's own headers as descriptors.
	if (nt->OptionalHeader.NumberOfRvaAndSizes<=IMAGE_DIRECTORY_ENTRY_IMPORT)
		return 0;
	DWORD importRva=nt->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
	if (!importRva)
		return 0;

	// Find the descriptor(s) for strExporter; the table ends with an all-zero entry.
	for (PIMAGE_IMPORT_DESCRIPTOR iid=(PIMAGE_IMPORT_DESCRIPTOR)(image+importRva);iid->Name;iid++)
	{
		if (strcmpi((LPCSTR)(image+iid->Name),strExporter)!=0)
			continue;

		// Names are only available through the lookup table (OriginalFirstThunk).
		// After loading, FirstThunk holds addresses, so without a lookup table
		// this descriptor cannot be searched by name.
		if (!iid->OriginalFirstThunk)
			continue;

		// Walk the lookup table and the IAT in parallel; entry i of one
		// corresponds to entry i of the other.
		PIMAGE_THUNK_DATA names=(PIMAGE_THUNK_DATA)(image+iid->OriginalFirstThunk);
		PIMAGE_THUNK_DATA slots=(PIMAGE_THUNK_DATA)(image+iid->FirstThunk);
		for (;names->u1.AddressOfData;names++,slots++)
		{
			// An ordinal import has no IMAGE_IMPORT_BY_NAME record to dereference.
			if (IMAGE_SNAP_BY_ORDINAL(names->u1.Ordinal))
				continue;

			PIMAGE_IMPORT_BY_NAME byName=(PIMAGE_IMPORT_BY_NAME)(image+(ULONG_PTR)names->u1.AddressOfData);
			if (strcmpi((LPCSTR)byName->Name,strFunc)==0)
				return (DWORD*)&slots->u1.Function;
		}
	}
	return 0;
}

// ***

// Scope guard that temporarily reverts a trap.
//
// The constructor calls Trap.Swap() and the destructor calls it again. While
// the trap is installed, the patched location therefore holds its original
// value only for the lifetime of the swapper. The trap class must grant this
// class access to Swap(); see the friend declarations in CTrap and CTrapJump.
template <class CTrap>
	class CTrapSwapper
{
protected:
	CTrap &Trap;

public:
	CTrapSwapper(CTrap &aTrap):Trap(aTrap)
	{
		Trap.Swap();
	}
	~CTrapSwapper()
	{
		Trap.Swap();
	}

private:
	// Non-copyable: every copy's destructor would call Swap() one extra time
	// and leave the trap in the wrong state. Declared but intentionally not
	// defined (pre-C++11 idiom, matching the rest of this header).
	CTrapSwapper(const CTrapSwapper&);
	CTrapSwapper& operator=(const CTrapSwapper&);
};

// Generic patch: stores a TData value at an address and restores the previous value on destruction

template <class TData>
	class CTrap
{
protected:
	TData* pAddr;
	TData nOldAddr;

	void Swap()
	{
		if (pAddr)
		{
			TData Temp;
			Temp=*pAddr,*pAddr=nOldAddr,nOldAddr=Temp;
		}
	}
	void Trap()
	{
		if (pAddr)
		{
			DWORD n;
			VirtualProtect(pAddr,4,PAGE_EXECUTE_READWRITE,&n);
			Swap();
		}
	}

	friend class CTrapSwapper<CTrap<TData> >;
	friend class CTrapJump;

public:
	CTrap(LPVOID pvAddr, TData nNewAddr):
		pAddr(static_cast<TData*>(pvAddr)), nOldAddr(nNewAddr)
	{
		Trap();
	}
	~CTrap()
	{
		Swap();
	}
};

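// Function-pointer trap: replaces an import address table slot or an entry of an object's function table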

// Trap that replaces a function pointer stored in memory with newFunc, and
// restores it on destruction.
//
// Pointers are handled as DWORDs (reinterpret_cast<DWORD>(newFunc)), so this
// class only works in 32-bit builds.
class CTrapAddr: public CTrap<DWORD>
{
public:
	// Hooks the IAT slot through which module hImporter calls strFunc imported
	// from strExporter. If the slot is not found, the trap is inert.
	CTrapAddr(HINSTANCE hImporter, LPCSTR strExporter, LPCSTR strFunc, LPVOID newFunc):
		CTrap<DWORD>(GetThunkAddres(hImporter, strExporter, strFunc),reinterpret_cast<DWORD>(newFunc)) {}

	// Same as above, but the importing module is looked up by name with GetModuleHandle.
	CTrapAddr(LPCSTR strImporter, LPCSTR strExporter, LPCSTR strFunc, LPVOID newFunc):
		CTrap<DWORD>(GetThunkAddres(GetModuleHandle(strImporter), strExporter, strFunc),reinterpret_cast<DWORD>(newFunc)) {}

	// Hooks entry nOffset of the table pointed to by the first pointer-sized
	// field of *pObject. For a polymorphic object this is presumably its
	// vtable (compiler/ABI dependent - confirm for your compiler).
	// nOffset counts DWORD-sized entries, not bytes. A null pObject yields an
	// inert trap.
	CTrapAddr(LPVOID pObject, int nOffset, LPVOID newFunc):
		CTrap<DWORD>(pObject ? reinterpret_cast<DWORD**>(pObject)[0]+nOffset : 0,
			reinterpret_cast<DWORD>(newFunc)) {}

	// Returns the value the next Swap() would write back, cast to PFunc. While
	// the trap is installed, this is the original function address. The
	// argument only serves to deduce PFunc.
	//
	// Returns 0 when the trap could not be installed. Otherwise nOldAddr would
	// still hold newFunc, and a hook calling "the old function" would recurse
	// into itself.
	template <typename PFunc>
		PFunc OldFunc(PFunc)
	{
		if (!pAddr)
			return 0;
		// static_cast cannot turn a data pointer into a function pointer;
		// converting the stored integer directly is valid for both.
		return reinterpret_cast<PFunc>(nOldAddr);
	}
};

// Trap that overwrites the start of a function with a 5-byte relative JMP to a replacement

// Raw byte layout of a one-byte opcode followed by a 32-bit operand, as written
// over the start of a function by CTrapJump. The 1-byte packing removes the
// padding after Command, so the struct occupies exactly 5 bytes.
#pragma pack(push, 1)
struct COneByteOffset
{
	BYTE  Command; // opcode byte (CTrapJump uses 0xE9, JMP rel32)
	DWORD Offset;  // operand; for JMP rel32, measured from the end of the 5-byte instruction
};
#pragma pack(pop)

// Trap that overwrites the first five bytes of a function with "JMP rel32"
// (opcode 0xE9) to a replacement, and restores the original bytes on destruction.
//
// Notes:
//  - While the trap is installed, calling the original function enters the
//    replacement. To reach the original, revert the patch temporarily with
//    CTrapSwapper<CTrapJump>.
//  - The displacement is computed from pointers cast to DWORD, so this class
//    only works in 32-bit builds.
//  - Copying is disabled: the object owns pTrap, and a copy would delete it twice.
class CTrapJump
{
protected:
	CTrap<COneByteOffset> *pTrap; // owned; allocated by every constructor

	// Installs "JMP pNewFunc" at pFunc. A null pFunc yields an inert trap.
	void Trap(LPVOID pFunc,LPVOID pNewFunc)
	{
		// rel32 is relative to the address following the 5-byte instruction.
		COneByteOffset Command = {0xE9, reinterpret_cast<DWORD>(pNewFunc)-reinterpret_cast<DWORD>(pFunc)-5};
		pTrap=new CTrap<COneByteOffset>(pFunc,Command);
	}
	void Swap()
	{
		pTrap->Swap();
	}

	friend class CTrapSwapper<CTrapJump>;

public:
	// Patches the function at pFunc.
	CTrapJump(LPVOID pFunc, LPVOID newFunc)
	{
		Trap(pFunc,newFunc);
	}

	// Patches a function exported by a module, located with GetProcAddress.
	// The explicit cast is needed because FARPROC -> LPVOID is not an implicit
	// conversion in standard C++ (GCC/MinGW reject it).
	CTrapJump(HINSTANCE hModule, LPCSTR strFuncName, LPVOID newFunc)
	{
		Trap(reinterpret_cast<LPVOID>(GetProcAddress(hModule, strFuncName)),newFunc);
	}
	CTrapJump(LPCSTR strModule, LPCSTR strFuncName, LPVOID newFunc)
	{
		Trap(reinterpret_cast<LPVOID>(GetProcAddress(GetModuleHandle(strModule), strFuncName)),newFunc);
	}

	// Patches the function that the importer's IAT slot currently points to.
	// If the slot is not found, the trap is inert instead of dereferencing null.
	CTrapJump(HINSTANCE hImporter, LPCSTR strExporter, LPCSTR strFunc, LPVOID newFunc)
	{
		DWORD *pThunk=GetThunkAddres(hImporter, strExporter, strFunc);
		Trap(pThunk ? reinterpret_cast<LPVOID>(*pThunk) : 0,newFunc);
	}
	CTrapJump(LPCSTR strImporter, LPCSTR strExporter, LPCSTR strFunc, LPVOID newFunc)
	{
		DWORD *pThunk=GetThunkAddres(GetModuleHandle(strImporter), strExporter, strFunc);
		Trap(pThunk ? reinterpret_cast<LPVOID>(*pThunk) : 0,newFunc);
	}

	~CTrapJump()
	{
		delete pTrap;
	}

private:
	// Non-copyable; declared but intentionally not defined (pre-C++11 idiom).
	CTrapJump(const CTrapJump&);
	CTrapJump& operator=(const CTrapJump&);
};

} // namespace traps
