// PyGatewayBase - the IUnknown Gateway Interface

#include "stdafx.h"

#include "PythonCOM.h"
#include "PyFactory.h"

#include "PythonCOMServer.h"


// {25D29CD0-9B98-11d0-AE79-4CF1CF000000}
// Private IID answered by PyGatewayBase::QueryInterface so that code holding
// a COM pointer can ask whether it is really a Python gateway and, if so,
// unwrap it back to the underlying Python object.
extern const GUID IID_IInternalUnwrapPythonObject = {
	0x25D29CD0, 0x9B98, 0x11D0,
	{ 0xAE, 0x79, 0x4C, 0xF1, 0xCF, 0x00, 0x00, 0x00 }
};

// Logging helper defined elsewhere in the PythonCOM sources; the signature
// suggests a printf-style format string (presumably - confirm in its source).
extern void PyCom_LogF(const char *fmt, ...);
// Short alias used by the DEBUG_FULL tracing statements in this file.
#define LogF PyCom_LogF

// #define DEBUG_FULL

// Number of PyGatewayBase objects currently alive.  Only ever modified via
// InterlockedIncrement/InterlockedDecrement (constructor / destructor).
static LONG cGateways = 0;

// Report how many gateway objects exist right now.  The value is a snapshot:
// other threads may create or destroy gateways immediately afterwards.
LONG _PyCom_GetGatewayCount(void)
{
	const LONG liveGateways = cGateways;
	return liveGateways;
}

/////////////////////////////////////////////////////////////////////////////
//

PyGatewayBase::PyGatewayBase(PyObject *instance)
{
	// One more live gateway (reported by _PyCom_GetGatewayCount).
	InterlockedIncrement(&cGateways);

	// The creator owns the initial COM reference.
	m_cRef = 1;

	// Hold our own reference to the wrapped Python object.  Callers are not
	// expected to pass NULL, but it is tolerated.
	m_pPyObject = instance;
	if ( instance != NULL )
		Py_INCREF(instance);

	// Balanced by PyCom_DLLReleaseRef() in the destructor (presumably keeps
	// the hosting DLL loaded while gateways exist).
	PyCom_DLLAddRef();

#ifdef DEBUG_FULL
	const char *typeName = (m_pPyObject != NULL) ? m_pPyObject->ob_type->tp_name : "<NULL>";
	LogF("PyGatewayBase: created %s", typeName);
#endif
}

PyGatewayBase::~PyGatewayBase()
{
	// Mirror of the constructor: one less live gateway.
	InterlockedDecrement(&cGateways);

#ifdef DEBUG_FULL
	const char *typeName = (m_pPyObject != NULL) ? m_pPyObject->ob_type->tp_name : "<NULL>";
	LogF("PyGatewayBase: deleted %s", typeName);
#endif

	// Drop our reference to the Python object.  The decref is bracketed by
	// PyCom_EnterPython/PyCom_LeavePython (presumably acquiring and releasing
	// the interpreter lock) because COM may release us from any thread.
	if ( m_pPyObject != NULL )
	{
		PyCom_EnterPython();
		Py_DECREF(m_pPyObject);
		PyCom_LeavePython();
	}

	// Balance the PyCom_DLLAddRef() made by the constructor.
	PyCom_DLLReleaseRef();
}

// IUnknown::QueryInterface for Python gateways.
//
// Resolution order:
//   1. IID_NULL or the gateway's own IID -> this gateway's native interface.
//   2. IUnknown / IDispatch               -> this object as IDispatch.
//   3. ISupportErrorInfo                  -> this object.
//   4. IID_IInternalUnwrapPythonObject    -> this object (lets callers
//                                            recover the wrapped Python object).
//   5. Otherwise the Python policy's _QueryInterface_(iid) is consulted:
//        - integer 1     -> a new registered gateway is created for iid;
//        - a non-integer -> converted to a COM interface (queried with
//                           IID_NULL) and returned directly;
//        - anything else, a failed conversion, or an exception
//                        -> E_NOINTERFACE (Python exceptions are cleared).
//
// Returns E_POINTER if ppv is NULL; otherwise *ppv is always written.
STDMETHODIMP PyGatewayBase::QueryInterface(
	REFIID iid,
	void ** ppv
	)
{
#ifdef DEBUG_FULL
	{
		USES_CONVERSION;
		OLECHAR oleRes[128];
		// StringFromGUID2's size argument is a character count, not bytes.
		StringFromGUID2(iid, oleRes, sizeof(oleRes) / sizeof(oleRes[0]));
		LogF("PyGatewayBase::QueryInterface: %s", OLE2T(oleRes));
	}
#endif

	// COM contract: a NULL out-pointer is reported, never dereferenced.
	if ( ppv == NULL )
		return E_POINTER;
	*ppv = NULL;

	// If our native interface, return this.
	if ( IsEqualIID(iid, IID_NULL) || IsEqualIID(iid, GetIID()) )
	{
		AddRef();
		*ppv = ThisAsIID(GetIID());
		return S_OK;
	}

	if ( IsEqualIID(iid, IID_IUnknown) ||
		 IsEqualIID(iid, IID_IDispatch) )
	{
		*ppv = (IDispatch *)this; // IDispatch * == IUnknown * for gateways.
		AddRef();
		return S_OK;
	}
	if ( IsEqualIID(iid, IID_ISupportErrorInfo) )
	{
		*ppv = (ISupportErrorInfo *)this;
		AddRef();
		return S_OK;
	}
	if ( IsEqualIID(iid, IID_IInternalUnwrapPythonObject) )
	{
		// Special IID for unwrapping (ie, downcasting) a Python object
		*ppv = (IInternalUnwrapPythonObject *)this;
		AddRef();
		return S_OK;
	}

	// Call the Python policy to see if it (says it) supports the interface
	long supports = 0;
	PyCom_EnterPython();
	{
		PyObject * ob = PyCom_PyIIDObjectFromIID(iid);
		if ( !ob )
		{
			PyCom_LeavePython();
			return E_OUTOFMEMORY;
		}

		PyObject *result = PyObject_CallMethod(m_pPyObject, "_QueryInterface_",
											   "O", ob);
		Py_DECREF(ob);

		if ( result )
		{
			if (PyInt_Check(result))
				supports = PyInt_AsLong(result);
			else {
				IUnknown *pUnk;
				// We know the object is a Python object.  And we _cant_ ask for the
				// actual IID QI'd on, as in some cases (ie, connection points) actually
				// return an IDispatch object (ie, the IID may not be known by the framework)
				// And we can't call with IDispatch or IUnknown, as for true gateways,
				// (IDispatch *)object != (ISpecificInterface *)
				// So we call with IID_NULL.  When a _Python_ gateway is QI'd with
				// IID_NULL, it will return its _native_ interface - ie, the most specialised
				// it can.
				if (PyCom_InterfaceFromPyObject(result, IID_NULL, (void **)&pUnk, FALSE)) {
					*ppv = pUnk;
					supports = 1;
				}
			}
			PyErr_Clear(); // ignore exceptions during conversion
			Py_DECREF(result);
		}
		else
		{
			PyErr_Clear();	// ### what to do with exceptions? ...
		}
	}
	PyCom_LeavePython();

	if ( supports != 1 )
		return E_NOINTERFACE;

	// Either the policy handed back a ready interface, or make a new gateway
	// object for iid (returning its result code).
	return *ppv ? S_OK : PyCom_MakeRegisteredGatewayObject(iid, m_pPyObject, ppv);
}

STDMETHODIMP_(ULONG) PyGatewayBase::AddRef(void)
{
	// Interlocked so concurrent AddRef/Release calls cannot lose updates.
	const LONG newCount = InterlockedIncrement(&m_cRef);
	return (ULONG)newCount;
}

STDMETHODIMP_(ULONG) PyGatewayBase::Release(void)
{
	// Work from a local copy: once the count hits zero 'this' is destroyed
	// and m_cRef must not be read again.
	const LONG remaining = InterlockedDecrement(&m_cRef);
	if ( remaining != 0 )
		return remaining;

	delete this;
	return 0;
}

STDMETHODIMP PyGatewayBase::GetTypeInfoCount(
	UINT FAR* pctInfo
	)
{
	/* ### eventually, let Python be able to return type info */
	if ( pctInfo != NULL )
	{
		// No type information is exposed yet.
		*pctInfo = 0;
		return S_OK;
	}
	return E_POINTER;
}

STDMETHODIMP PyGatewayBase::GetTypeInfo(
	UINT itinfo,
	LCID lcid,
	ITypeInfo FAR* FAR* pptInfo
	)
{
	if ( !pptInfo )
		return E_POINTER;

	// GetTypeInfoCount reports zero type infos, so every index is invalid.
	/* ### eventually, let Python be able to return type info */
	*pptInfo = NULL;
	return DISP_E_BADINDEX;
}

static HRESULT getids_setup(
	UINT cNames,
	OLECHAR FAR* FAR* rgszNames,
	LCID lcid,
	PyObject **pPyArgList,
	PyObject **pPyLCID
	)
{
	USES_CONVERSION;
	PyObject *argList = PyTuple_New(cNames);
	if ( !argList )
	{
		PyErr_Clear();	/* ### what to do with exceptions? ... */
		return E_OUTOFMEMORY;
	}

	for ( UINT i = 0; i < cNames; i++ )
	{
		/* ### correct conversion function? */
		char *s = OLE2T(rgszNames[i]);

		PyObject *ob = PyString_FromString(s);
		if ( !ob )
		{
			PyErr_Clear();	/* ### what to do with exceptions? ... */
			Py_DECREF(argList);
			return E_OUTOFMEMORY;
		}

		/* Note: this takes our reference for us (even if it fails) */
		if ( PyTuple_SetItem(argList, i, ob) == -1 )
		{
			PyErr_Clear();	/* ### what to do with exceptions? ... */
			Py_DECREF(argList);
			return E_FAIL;
		}
	}

	/* use the double stuff to keep lcid unsigned... */
	PyObject * py_lcid = PyLong_FromDouble((double)lcid);
	if ( !py_lcid )
	{
		PyErr_Clear();	/* ### what to do with exceptions? ... */
		Py_DECREF(argList);
		return E_FAIL;
	}

	*pPyArgList = argList;
	*pPyLCID = py_lcid;

	return S_OK;
}

// Convert the result of the policy's _GetIDsOfNames_ call into rgdispid.
//
// 'result' is consumed (its reference is released on every path).
//   * NULL result             -> the pending Python error is mapped via
//                                PyCom_HandlePythonFailureToCOM().
//   * not a sequence, or a length different from cNames -> E_FAIL.
//   * otherwise each item is stored in rgdispid[i]; if any entry equals
//     DISPID_UNKNOWN (including entries that are not integers) the result
//     is DISP_E_UNKNOWNNAME, else S_OK.
// No Python exception is left pending when this returns.
static HRESULT getids_finish(
	PyObject *result,
	UINT cNames,
	DISPID FAR* rgdispid
	)
{
	if ( !result )
		return PyCom_HandlePythonFailureToCOM();

	if ( !PySequence_Check(result) )
	{
		Py_DECREF(result);
		return E_FAIL;
	}

	// PyObject_Length reports failure as -1 with an exception set; test the
	// sign before the unsigned comparison so that case is handled explicitly.
	int count = PyObject_Length(result);
	if ( count < 0 || (UINT)count != cNames )
	{
		PyErr_Clear();	/* ### toss any potential exception */
		Py_DECREF(result);
		return E_FAIL;
	}

	HRESULT hr = S_OK;
	for ( UINT i = 0; i < cNames; ++i )
	{
		PyObject *ob = PySequence_GetItem(result, i);
		if ( !ob )
		{
			PyErr_Clear();	/* ### what to do with exceptions? ... */
			Py_DECREF(result);
			return E_FAIL;
		}

		DISPID id = PyInt_AsLong(ob);
		Py_DECREF(ob);
		if ( id == -1 && PyErr_Occurred() )
		{
			// Not an integer: the -1 error value equals DISPID_UNKNOWN, so the
			// name is reported as unknown below; just don't leak the exception.
			PyErr_Clear();
		}

		rgdispid[i] = id;
		if ( id == DISPID_UNKNOWN )
			hr = DISP_E_UNKNOWNNAME;
	}

	Py_DECREF(result);

	return hr;
}

STDMETHODIMP PyGatewayBase::GetIDsOfNames(
	REFIID refiid,
	OLECHAR FAR* FAR* rgszNames,
	UINT cNames,
	LCID lcid,
	DISPID FAR* rgdispid
	)
{
#ifdef DEBUG_FULL
	LogF("PyGatewayBase::GetIDsOfNames");
#endif

	PyObject *argList = NULL;
	PyObject *py_lcid = NULL;

	PyCom_EnterPython();

	HRESULT hr = getids_setup(cNames, rgszNames, lcid, &argList, &py_lcid);
	if ( FAILED(hr) )
	{
		PyCom_LeavePython();
		return hr;
	}

	// Delegate to the policy: _GetIDsOfNames_(names_tuple, lcid).
	PyObject *result = PyObject_CallMethod(m_pPyObject, "_GetIDsOfNames_", "OO", argList, py_lcid);
	Py_DECREF(argList);
	Py_DECREF(py_lcid);

	// getids_finish takes ownership of result (and copes with NULL).
	hr = getids_finish(result, cNames, rgdispid);

	PyCom_LeavePython();
	return hr;
}

// Build the arguments for the policy's _Invoke_ call from DISPPARAMS.
//
// On success returns S_OK and stores new references in *pPyArgList (a tuple
// holding the positional arguments in Python order) and *pPyLCID (lcid as a
// Python long).  On failure returns E_OUTOFMEMORY/E_FAIL with any Python
// exception cleared and the out-parameters untouched.
static HRESULT invoke_setup(
	DISPPARAMS FAR* params,
	LCID lcid,
	PyObject **pPyArgList,
	PyObject **pPyLCID
	)
{
	const UINT nArgs = params->cArgs;
	PyObject *argList = PyTuple_New(nArgs);
	if ( argList == NULL )
	{
		PyErr_Clear();	/* ### what to do with exceptions? ... */
		return E_OUTOFMEMORY;
	}

	// DISPPARAMS stores arguments in reverse order: rgvarg[0] is the last
	// argument, so walk rgvarg forwards and fill the tuple from the back.
	for ( UINT src = 0; src < nArgs; ++src )
	{
		const UINT dst = nArgs - 1 - src;

		PyObject *item = PyCom_MakeVariantToPyObject(&params->rgvarg[src]);
		if ( item == NULL )
		{
			PyErr_Clear();	/* ### what to do with exceptions? ... */
			Py_DECREF(argList);
			return E_OUTOFMEMORY;
		}

		// PyTuple_SetItem steals the reference to item even on failure.
		if ( PyTuple_SetItem(argList, dst, item) == -1 )
		{
			PyErr_Clear();	/* ### what to do with exceptions? ... */
			Py_DECREF(argList);
			return E_FAIL;
		}
	}

	// Going through a double keeps the (unsigned) lcid from turning negative.
	PyObject *pyLcid = PyLong_FromDouble((double)lcid);
	if ( pyLcid == NULL )
	{
		PyErr_Clear();	/* ### what to do with exceptions? ... */
		Py_DECREF(argList);
		return E_FAIL;
	}

	*pPyArgList = argList;
	*pPyLCID = pyLcid;
	return S_OK;
}

// Translate the value returned by the policy's _Invoke_ into COM results.
//
// 'result' is consumed (its reference is released on every path).
// Accepted shapes:
//   * a number                      -> used directly as the HRESULT;
//   * a sequence (hr, argErr, value, ...):
//       [0] the HRESULT,
//       [1] stored in *puArgErr when puArgErr is non-NULL,
//       [2] converted into *pVarResult when pVarResult is non-NULL.
//     Slots the sequence does not actually contain are not read.
// Any other object yields E_FAIL.  Values that cannot be converted to an
// integer HRESULT give E_FAIL.  No Python exception is left pending.
static HRESULT invoke_finish(
	PyObject *result,
	VARIANT FAR* pVarResult,
	UINT FAR* puArgErr
	)
{
	HRESULT hr;

	if ( PyNumber_Check(result) )
	{
		hr = PyInt_AsLong(result);
		if ( hr == -1 && PyErr_Occurred() )
		{
			// Not representable as a C long: report failure instead of
			// returning -1 with the exception still pending.
			PyErr_Clear();
			hr = E_FAIL;
		}
		Py_DECREF(result);
		return hr;
	}
	if ( !PySequence_Check(result) )
	{
		Py_DECREF(result);
		return E_FAIL;
	}

	PyObject *ob = PySequence_GetItem(result, 0);
	if ( !ob )
	{
		PyErr_Clear();	/* ### what to do with exceptions? ... */
		Py_DECREF(result);
		return E_FAIL;
	}
	hr = PyInt_AsLong(ob);
	Py_DECREF(ob);
	if ( hr == -1 && PyErr_Occurred() )
	{
		PyErr_Clear();
		hr = E_FAIL;
	}

	int count = PyObject_Length(result);
	if ( count < 0 )
	{
		// Length unavailable: drop the error and use only item [0].
		PyErr_Clear();
		count = 0;
	}

	// Only touch slot [1] if the sequence really has it.
	if ( puArgErr && count > 1 )
	{
		ob = PySequence_GetItem(result, 1);
		if ( !ob )
		{
			PyErr_Clear();	/* ### what to do with exceptions? ... */
			Py_DECREF(result);
			return E_FAIL;
		}

		long argErr = PyInt_AsLong(ob);
		Py_DECREF(ob);
		if ( argErr == -1 && PyErr_Occurred() )
		{
			// Not an integer: leave *puArgErr untouched.
			PyErr_Clear();
		}
		else
		{
			*puArgErr = (UINT)argErr;
		}
	}

	// Only touch slot [2] if the sequence really has it.
	if ( pVarResult && count > 2 )
	{
		ob = PySequence_GetItem(result, 2);
		if ( !ob )
		{
			PyErr_Clear();	/* ### what to do with exceptions? ... */
			Py_DECREF(result);
			return E_FAIL;
		}

		BOOL success = PyCom_MakePyObjectToVariant(ob, pVarResult);
		/* for now: */ PyErr_Clear();
		if ( !success )
			hr = E_FAIL;
		Py_DECREF(ob);
	}

	if ( count > 3 )
	{
		/* ### copy extra results into associated [out] params */
	}

	Py_DECREF(result);

	return hr;
}

STDMETHODIMP PyGatewayBase::Invoke(
	DISPID dispid,
	REFIID riid,
	LCID lcid,
	WORD wFlags,
	DISPPARAMS FAR* params,
	VARIANT FAR* pVarResult,
	EXCEPINFO FAR* pexcepinfo,
	UINT FAR* puArgErr
	)
{
#ifdef DEBUG_FULL
	LogF("PyGatewayBase::Invoke; dispid=%ld", dispid);
#endif

	if ( pVarResult != NULL )
		V_VT(pVarResult) = VT_EMPTY;

	/* ### for now: no named args unless it is a PUT operation */
	// Accept zero named arguments, or exactly one that is DISPID_PROPERTYPUT.
	const UINT nNamed = params->cNamedArgs;
	if ( nNamed > 1 ||
		 (nNamed == 1 && params->rgdispidNamedArgs[0] != DISPID_PROPERTYPUT) )
		return DISP_E_NONAMEDARGS;

	HRESULT hr;
	PyCom_EnterPython();

	PyObject *argList;
	PyObject *py_lcid;
	hr = invoke_setup(params, lcid, &argList, &py_lcid);
	if ( SUCCEEDED(hr) )
	{
		// Delegate to the policy: _Invoke_(dispid, lcid, wFlags, args).
		PyObject *result = PyObject_CallMethod(m_pPyObject, "_Invoke_", "iOiO",
											   dispid, py_lcid, wFlags, argList);
		Py_DECREF(argList);
		Py_DECREF(py_lcid);

		// invoke_finish consumes result; a NULL result means a Python error.
		hr = (result != NULL)
			? invoke_finish(result, pVarResult, puArgErr)
			: PyCom_HandlePythonFailureToCOM(pexcepinfo);
	}

	PyCom_LeavePython();
	return hr;
}

// Extra Python helpers...
// Call the policy's _Invoke_ method by name and return the value it produced.
//
// Builds an argument tuple from szFormat/va (an empty tuple when szFormat is
// NULL; a non-tuple value is wrapped in a 1-tuple), then calls
//     pPyObject._Invoke_(szMethodName, 0, DISPATCH_METHOD, args)
// and returns a new reference to item [2] of the sequence it returns.
// On failure returns NULL with a Python exception set; every temporary
// reference created here is released on all paths.
static PyObject *do_dispatch(
	PyObject *pPyObject,
	const char *szMethodName,
	const char *szFormat,
	va_list va
	)
{
	// Build the Invoke arguments...
	PyObject *args;
	if ( szFormat )
		args = Py_VaBuildValue((char *)szFormat, va);
	else
		args = PyTuple_New(0);
	if ( !args )
		return NULL;

	// make sure a tuple.
	if ( !PyTuple_Check(args) )
	{
		PyObject *a = PyTuple_New(1);
		if ( a == NULL )
		{
			PyTS_DECREF(args);
			return NULL;
		}
		PyTuple_SET_ITEM(a, 0, args);	// steals the reference to args
		args = a;
	}

	PyObject *method = PyObject_GetAttrString(pPyObject, "_Invoke_");
	if ( !method )
	{
		// Release the argument tuple built above; it previously leaked here.
		PyTS_DECREF(args);
		PyErr_SetString(PyExc_AttributeError, (char *)szMethodName);
		return NULL;
	}

	// Make the call to _Invoke_
	PyObject *result = PyObject_CallFunction(method,
											 "siiO",
											 szMethodName,
											 0,
											 DISPATCH_METHOD,
											 args);
	PyTS_DECREF(method);
	PyTS_DECREF(args);
	if ( !result )
		return NULL;

	if ( !PySequence_Check(result) )
	{
		Py_DECREF(result);
		PyErr_SetString(PyExc_RuntimeError, "bad return type from _Invoke_");
		return NULL;
	}

	PyObject *ob = PySequence_GetItem(result, 2);
	Py_DECREF(result);
	if ( !ob )
	{
		PyErr_SetString(PyExc_RuntimeError, "_Invoke_ return value too short");
		return NULL;
	}

	return ob;
}

PyObject *PyGatewayBase::DispatchViaPolicy(const char *szMethodName, const char *szFormat, ...)
{
	// Both the wrapped object and a method name are required.
	if ( m_pPyObject == NULL || szMethodName == NULL )
		return OleSetTypeError("The argument is invalid");

	// Forward the variadic arguments to the shared dispatch helper; the
	// returned reference (or NULL with an exception set) goes to the caller.
	va_list args;
	va_start(args, szFormat);
	PyObject *ret = do_dispatch(m_pPyObject, szMethodName, szFormat, args);
	va_end(args);
	return ret;
}

STDMETHODIMP PyGatewayBase::InvokeViaPolicy(
	const char *szMethodName,
	PyObject **ppResult /* = NULL */,
	const char *szFormat /* = NULL */,
	...
	)
{
	if ( !m_pPyObject || !szMethodName )
		return E_POINTER;

	va_list args;
	va_start(args, szFormat);
	PyObject *ret = do_dispatch(m_pPyObject, szMethodName, szFormat, args);
	va_end(args);

	// Map whatever Python error state do_dispatch left into an HRESULT for
	// this gateway's interface.
	const HRESULT hr = PyCom_SetFromPyException(GetIID());

	// Give the result reference to the caller if requested, otherwise drop it.
	if ( ppResult == NULL )
	{
		Py_XDECREF(ret);
	}
	else
	{
		*ppResult = ret;
	}
	return hr;
}

STDMETHODIMP PyGatewayBase::InterfaceSupportsErrorInfo(REFIID riid)
{
	// Rich error information is only advertised for the gateway's own IID.
	return IsEqualGUID(riid, GetIID()) ? S_OK : S_FALSE;
}

// IInternalUnwrapPythonObject::Unwrap - hand back the wrapped Python object.
//
// On success stores a new reference in *pPyObject and returns S_OK.
// Returns E_POINTER if pPyObject is NULL.  The constructor and destructor
// tolerate a NULL wrapped object, so that case is reported as E_FAIL (with
// *pPyObject set to NULL) instead of crashing in Py_INCREF.
STDMETHODIMP PyGatewayBase::Unwrap(
            /* [out] */ PyObject **pPyObject)
{
	if (pPyObject==NULL)
		return E_POINTER;
	*pPyObject = m_pPyObject;
	if ( m_pPyObject == NULL )
		return E_FAIL;
	Py_INCREF(m_pPyObject);
	return S_OK;
}
