/* 
 * Copyright (c) 2024, Ashok P. Nadkarni
 * All rights reserved.
 *
 * See the file LICENSE for license
 *
 * COM server implementation
 */

#include "twapi.h"
#include "twapi_com.h"

/*
 * IDispatch server implementation
 */
/*
 * Forward declarations of the IUnknown/IDispatch methods of Twapi_ComServer.
 * They must be declared before Twapi_ComServer_Vtbl, which references them.
 */
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_QueryInterface(IDispatch *this, REFIID riid, void **ifcPP);
static ULONG STDMETHODCALLTYPE Twapi_ComServer_AddRef(IDispatch *this);
static ULONG STDMETHODCALLTYPE Twapi_ComServer_Release(IDispatch *this);
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_GetTypeInfoCount(IDispatch *this, UINT *pctP);
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_GetTypeInfo(IDispatch *this, UINT tinfo, LCID lcid,
                                                             ITypeInfo **tiPP);
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_GetIDsOfNames(IDispatch *this, REFIID riid,
                                                               LPOLESTR *namesP, UINT namesc,
                                                               LCID lcid, DISPID *rgDispId);
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_Invoke(IDispatch *this, DISPID dispIdMember,
                                                        REFIID riid, LCID lcid, WORD flags,
                                                        DISPPARAMS *dispparamsP,
                                                        VARIANT *resultvarP,
                                                        EXCEPINFO *excepP, UINT *argErrP);


/*
 * Vtbl for Twapi_ComServer. Entries must be listed in IDispatchVtbl slot
 * order: the three IUnknown methods first, then the four IDispatch methods.
 * Twapi_ComServerObjCmd points each new object's lpVtbl at this table.
 */
static struct IDispatchVtbl Twapi_ComServer_Vtbl = {
    Twapi_ComServer_QueryInterface,
    Twapi_ComServer_AddRef,
    Twapi_ComServer_Release,
    Twapi_ComServer_GetTypeInfoCount,
    Twapi_ComServer_GetTypeInfo,
    Twapi_ComServer_GetIDsOfNames,
    Twapi_ComServer_Invoke
};

/*
 * Instance data for an automation (IDispatch) object whose member calls
 * are forwarded to a Tcl callback by Twapi_ComServer_Invoke.
 *
 * The methods cast the IDispatch pointer they receive directly to
 * Twapi_ComServer *, which is why idispP must remain the first field.
 *
 * TBD - does this (related methods) need to be made thread safe?
 * (refc is updated with plain, non-atomic arithmetic.)
 */
typedef struct Twapi_ComServer {
    interface IDispatch idispP; /* Must be first field */
    IID iid;                    /* IID for this interface. TBD - needed ? We only implement IDispatch, not the vtable for any other IID */
    int refc;                   /* Ref count; object is freed when it drops to 0 */
    TwapiInterpContext *ticP;   /* Interpreter and related context */
    Tcl_Obj *memids;            /* Flat list of alternating {memid name}
                                   pairs mapping member ids to names */
    Tcl_Obj *cmd;               /* Stores the callback command prefix; Invoke
                                   appends the member name and arguments */
} Twapi_ComServer;


/*
 * Forward declarations of the IUnknown/IClassFactory methods of
 * Twapi_ClassFactory, referenced by Twapi_ClassFactory_Vtbl below.
 */
static HRESULT STDMETHODCALLTYPE Twapi_ClassFactory_QueryInterface(IClassFactory *this, REFIID riid,
                                                                   void **ifcPP);
static ULONG STDMETHODCALLTYPE Twapi_ClassFactory_AddRef(IClassFactory *this);
static ULONG STDMETHODCALLTYPE Twapi_ClassFactory_Release(IClassFactory *this);
static HRESULT STDMETHODCALLTYPE Twapi_ClassFactory_CreateInstance(IClassFactory *this,
                                                                   LPUNKNOWN pUnkOuter,
                                                                   REFIID riid,
                                                                   LPVOID *ppvObject);
static HRESULT STDMETHODCALLTYPE Twapi_ClassFactory_LockServer(IClassFactory *this, BOOL fLock);

/*
 * Vtbl for Twapi_ClassFactory. Entries must be listed in IClassFactoryVtbl
 * slot order: the three IUnknown methods, then CreateInstance and
 * LockServer. Twapi_ClassFactoryObjCmd points each factory's lpVtbl here.
 */
static struct IClassFactoryVtbl Twapi_ClassFactory_Vtbl = {
    Twapi_ClassFactory_QueryInterface,
    Twapi_ClassFactory_AddRef,
    Twapi_ClassFactory_Release,
    Twapi_ClassFactory_CreateInstance,
    Twapi_ClassFactory_LockServer
};

/*
 * Instance data for a class factory whose CreateInstance is delegated to a
 * Tcl callback. The methods cast the IClassFactory pointer they receive
 * directly to Twapi_ClassFactory *, so ifacP must remain the first field.
 */
typedef struct Twapi_ClassFactory {
    interface IClassFactory ifacP; /* Must be first field */
    CLSID clsid;                   /* CLSID for this class */
    int refc;                   /* Ref count; freed when it drops to 0 */
    TwapiInterpContext *ticP;   /* Interpreter and related context */
    Tcl_Obj *memids;            /* List mapping member names to integer ids
                                   for the objects generated by this factory.
                                   Release decrements it if non-NULL; nothing
                                   in this file stores a list here today. */
    Tcl_Obj *cmd;               /* Stores the callback command prefix for
                                   creating objects; CreateInstance appends
                                   the requested IID */
} Twapi_ClassFactory;

/*
 * Convert an IClassFactory pointer to a Tcl_Obj and back, via the
 * ObjFromOpaque/ObjToOpaque helpers with the type tag "IClassFactory".
 * NOTE(review): the exact object representation is defined by those
 * helpers elsewhere in twapi.
 */
#define ObjFromIClassFactory(p_) ObjFromOpaque((p_), "IClassFactory")
#define ObjToIClassFactory(ip_, obj_, ifc_) \
    ObjToOpaque((ip_), (obj_), (ifc_), "IClassFactory")

static HRESULT TwapiMapErrorToHRESULT(int hr)
{
    /* Evals may return Tcl codes or HRESULTs. If it is not a failure
       code, assume it is a Tcl code and map it to generic HRESULT error */
    if (! FAILED(hr)) {
        hr = E_FAIL;
    }
    return hr;
}

/*
 * Signals the script that no COM objects remain anywhere in the process.
 *
 * Whether the process then exits is left to the script. A COM server may
 * choose to exit; a COM client that merely hosts event sinks may keep
 * running.
 * TBD - how to deal with multiple threads and multiple interpreters?
 *
 * This is reached from a COM proxy/stub context, so no script is invoked
 * directly (which might require saving/restoring interp state). Instead
 * ::twapi::com_shutdown_signal is set to 1, on the expectation that the
 * interp is vwaiting on it.
 * TBD - is this the best way ?
 */
static void TwapiComServerShutdown(Tcl_Interp *interp)
{
    if (interp == NULL || Tcl_InterpDeleted(interp))
        return;                 /* Nobody left to notify */

    Tcl_SetVar(interp, "::twapi::com_shutdown_signal", "1", 0);
}

/*
 * Returns the member name registered for dispid in me->memids, a flat list
 * of alternating {memid name} pairs. The returned object is owned by the
 * list; callers that keep it must take their own reference.
 *
 * Returns NULL if the id is not present, or if the map is malformed
 * (odd length or a non-integer memid). The latter should not happen since
 * the map is validated when the object is created.
 */
static Tcl_Obj *TwapiComServerMemIdToName(Twapi_ComServer *me, DISPID dispid)
{
    Tcl_Obj **elems;
    Tcl_Size nelems, pos;

    if (ObjGetElements(NULL, me->memids, &nelems, &elems) != TCL_OK)
        return NULL;            /* Should not happen */
    if (nelems % 2 != 0)
        return NULL;            /* Should not happen */

    for (pos = 0; pos + 1 < nelems; pos += 2) {
        int candidate;

        if (ObjToInt(NULL, elems[pos], &candidate) != TCL_OK)
            return NULL;        /* Should not happen */
        if (candidate == dispid)
            return elems[pos + 1];
    }

    return NULL;
}

/*
 * Looks up the DISPID registered for a member name in me->memids, a flat
 * list of alternating {memid name} pairs. Names are compared
 * case-insensitively (lstrcmpiW).
 *
 * Returns:
 *   S_OK               - found; the id is stored in *idP
 *   DISP_E_UNKNOWNNAME - no member has that name. This is an ordinary,
 *                        caller-visible condition (e.g. a client probing
 *                        for an optional member), and it is the code
 *                        IDispatch::GetIDsOfNames is required to report.
 *   E_NOTIMPL          - the map itself is malformed. Should not happen,
 *                        since Twapi_ComServerObjCmd validates it.
 */
static HRESULT TwapiComServerNameToMemId(Twapi_ComServer *me, LPCWSTR name, DISPID *idP)
{
    Tcl_Obj **objs;
    Tcl_Size i, nobjs;

    if (ObjGetElements(NULL, me->memids, &nobjs, &objs) != TCL_OK ||
        (nobjs & 1) != 0)
        return E_NOTIMPL;            /* Malformed map - should not happen */

    /* Names live at the odd indices, their ids just before them */
    for (i = 1; i < nobjs; i += 2) {
        if (lstrcmpiW(name, ObjToWinChars(objs[i])) == 0) {
            int memid;
            if (ObjToInt(NULL, objs[i-1], &memid) != TCL_OK)
                return E_NOTIMPL;        /* Should not happen */
            *idP = memid;
            return S_OK;
        }
    }

    /* Name not found. Report it as an unknown name, not as "not implemented". */
    return DISP_E_UNKNOWNNAME;
}


/*
 * IUnknown::QueryInterface for Twapi_ComServer.
 *
 * Succeeds for IID_IUnknown, IID_IDispatch and the IID the object was
 * created with. In every case the same IDispatch pointer is returned,
 * AddRef'ed. Returns E_POINTER if ifcPP is NULL, and E_NOINTERFACE (with
 * *ifcPP set to NULL) for any other IID.
 */
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_QueryInterface(
    IDispatch *this,
    REFIID riid,
    void **ifcPP)
{
    /* TBD - Should not check against this->iid ? Because we only implement IUNknown and IDispatch, not the vtable for the iid (e.g. in case of a sink) */

    /* A NULL out-pointer must be rejected rather than dereferenced */
    if (ifcPP == NULL)
        return E_POINTER;

    if (!IsEqualIID(riid, &((Twapi_ComServer *)this)->iid) &&
        !IsEqualIID(riid, &IID_IUnknown) &&
        !IsEqualIID(riid, &IID_IDispatch)) {
        /* Not a supported interface */
        *ifcPP = NULL;
        return E_NOINTERFACE;
    }

    this->lpVtbl->AddRef(this);
    *ifcPP = this;
    return S_OK;
}


/*
 * IUnknown::AddRef - bumps the reference count and returns the new value.
 * The update is not atomic (see the thread-safety TBD on Twapi_ComServer).
 */
static ULONG STDMETHODCALLTYPE Twapi_ComServer_AddRef(IDispatch *this)
{
    Twapi_ComServer *self = (Twapi_ComServer *) this;

    return ++self->refc;
}

/*
 * IUnknown::Release - drops one reference and returns the remaining count.
 *
 * When the count reaches zero:
 * - the held Tcl objects are released;
 * - the server-process reference taken at creation
 *   (CoAddRefServerProcess) is released, and if it was the last one the
 *   script is signalled via TwapiComServerShutdown;
 * - the interpreter context reference is released and the object is freed.
 */
static ULONG STDMETHODCALLTYPE Twapi_ComServer_Release(IDispatch *this)
{
    Twapi_ComServer *self = (Twapi_ComServer *) this;
    TwapiInterpContext *ctxP;

    if (--self->refc != 0)
        return self->refc;

    if (self->memids)
        ObjDecrRefs(self->memids);
    if (self->cmd)
        ObjDecrRefs(self->cmd);

    ctxP = self->ticP;
    if (ctxP != NULL) {
        /* If no more objects (of any type) we can shutdown the process */
        if (CoReleaseServerProcess() == 0)
            TwapiComServerShutdown(ctxP->interp);
        TwapiInterpContextUnref(ctxP, 1);
    }

    TwapiFree(self);
    return 0;
}

/*
 * IDispatch::GetTypeInfoCount - this object exposes no type information,
 * so the count reported through pctP (when supplied) is always zero.
 */
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_GetTypeInfoCount(
    IDispatch *this,
    UINT *pctP)
{
    if (pctP != NULL)
        *pctP = 0;

    return S_OK;
}


/*
 * IDispatch::GetTypeInfo - not supported, since GetTypeInfoCount reports
 * zero type infos.
 *
 * Per COM out-parameter rules, the output pointer is set to NULL (when
 * supplied) so callers never see an uninitialized interface pointer.
 */
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_GetTypeInfo(
    IDispatch *this,
    UINT tinfo,
    LCID lcid,
    ITypeInfo **tiPP)
{
    if (tiPP != NULL)
        *tiPP = NULL;
    return E_NOTIMPL;
}

/*
 * IDispatch::GetIDsOfNames.
 *
 * Only the member name (namesP[0]) is resolved, via the object's memid map.
 * Every entry of rgDispId is first set to DISPID_UNKNOWN. Parameter names
 * (namesP[1..]) are never resolved, so if any were requested the call
 * reports DISP_E_UNKNOWNNAME even when the member name was found.
 *
 * A non-NULL riid other than IID_NULL is rejected with
 * DISP_E_UNKNOWNINTERFACE.
 */
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_GetIDsOfNames(
    IDispatch *this,
    REFIID   riid,
    LPOLESTR *namesP,
    UINT namesc,
    LCID lcid,
    DISPID *rgDispId)
{
    HRESULT status;
    UINT n;

    if (namesc == 0)
        return DISP_E_UNKNOWNNAME;

    if (riid != NULL && !IsEqualIID(&IID_NULL, riid))
        return DISP_E_UNKNOWNINTERFACE;

    for (n = 0; n < namesc; n++)
        rgDispId[n] = DISPID_UNKNOWN;

    /* TBD - need to take lcid into account ? */
    status = TwapiComServerNameToMemId((Twapi_ComServer *) this, namesP[0], rgDispId);
    if (status != S_OK)
        return status;

    if (namesc > 1)
        return DISP_E_UNKNOWNNAME;  /* Could not map the parameter names */
    return S_OK;
}

/*
 * IDispatch::Invoke - forwards a member call to the Tcl callback.
 *
 * Evaluation:
 * - The DISPID is mapped back to its member name via me->memids.
 * - The command prefix me->cmd, followed by the member name and the
 *   positional arguments, is evaluated at global level in the owning
 *   interpreter.
 * - Arguments are appended in call order; DISPPARAMS stores them reversed.
 *   Named arguments (cNamedArgs) are not examined.
 *
 * Results and errors:
 * - On success the command result is converted into *retvarP, if supplied.
 * - If the evaluation fails, the error is reported through
 *   Tcl_BackgroundError. If excepP is supplied, the error message is also
 *   returned in excepP->bstrDescription and DISP_E_EXCEPTION is returned.
 *
 * Must be called on the interpreter's thread (panics otherwise). The
 * object holds an extra reference for the duration of the evaluation, so
 * a callback that releases it cannot free it mid-call.
 */
static HRESULT STDMETHODCALLTYPE Twapi_ComServer_Invoke(
    IDispatch *this,
    DISPID dispid,
    REFIID riid,
    LCID lcid,
    WORD flags,
    DISPPARAMS *dispparamsP,
    VARIANT *retvarP,
    EXCEPINFO *excepP,
    UINT *argErrP)
{
    Twapi_ComServer *me = (Twapi_ComServer *) this;
    HRESULT hr;
    Tcl_Obj **cmdobjv = NULL;
    Tcl_Obj **cmdprefixv;
    Tcl_Size  i, cmdobjc;
    Tcl_InterpState savedState;
    Tcl_Interp *interp;
    Tcl_Obj *memberNameObj;
    int argErr = -1;
    BSTR errorBstr = NULL;

    /* TBD - should we clear retvarP right at start ? */

    if (me == NULL || me->ticP == NULL || me->ticP->interp == NULL)
        return E_POINTER;

    if (me->ticP->thread != Tcl_GetCurrentThread())
        Tcl_Panic("Twapi_ComServer_Invoke called from non-interpreter thread");

    interp = me->ticP->interp;
    if (Tcl_InterpDeleted(interp))
        return E_POINTER;

    if (ObjGetElements(NULL, me->cmd, &cmdobjc, &cmdprefixv) != TCL_OK) {
        /* Internal error - should not happen. Should we log background error?*/
        return E_FAIL;
    }

    if (flags == DISPATCH_PROPERTYPUTREF) {
        /* TBD - better error code */
        return E_FAIL;
    }

    memberNameObj = TwapiComServerMemIdToName(me, dispid);
    if (memberNameObj == NULL) {
        /* Should not really happen. Log internal error ? */
        return E_FAIL;
    }

    /*
     * Before eval'ing, addref ourselves so we don't get deleted in a
     * recursive callback
     */
    this->lpVtbl->AddRef(this);


    /* Note we will tack on member name plus dispparms */
    i = cmdobjc + 1;
    if (dispparamsP)
        i += dispparamsP->cArgs;
    cmdobjv = MemLifoPushFrame(me->ticP->memlifoP, i * sizeof(*cmdobjv), NULL);

    for (i = 0; i < cmdobjc; ++i) {
        cmdobjv[i] = cmdprefixv[i];
        ObjIncrRefs(cmdobjv[i]);
    }

    ObjIncrRefs(memberNameObj);
    cmdobjv[cmdobjc] = memberNameObj;
    cmdobjc += 1;

    /* Add the passed parameters */
    if (dispparamsP) {
        /*
         * Parameters are stored in reverse order, so walk backwards.
         * cArgs is an unsigned UINT, so it is converted to the signed
         * Tcl_Size BEFORE subtracting 1. Otherwise "cArgs - 1" wraps to
         * UINT_MAX when there are no arguments. With a 64-bit Tcl_Size
         * that value stays positive and rgvarg would be indexed far out
         * of bounds.
         */
        for (i = (Tcl_Size) dispparamsP->cArgs - 1; i >= 0 ; --i) {
            /* Verify that we can handle the parameter types */
            if (dispparamsP->rgvarg[i].vt & VT_BYREF) {
                /* TBD - need to be able to handle BYREF for inout/out params */
                if (0) {
                    hr = DISP_E_TYPEMISMATCH;
                    goto vamoose;
                }
            }
            cmdobjv[cmdobjc] = ObjFromVARIANT(&dispparamsP->rgvarg[i], 1);
            ObjIncrRefs(cmdobjv[cmdobjc]);
            ++cmdobjc;
        }
    }

    /* TBD - is this safe as we are being called from the message dispatch
       loop? Or should we queue to pending callback queue ? But in that
       case we cannot get results back as we can't block in this thread
       as the script invocation will also be in this thread. Also, is
       the Tcl_SaveInterpState/RestoreInterpState really necessary ?
       Note tclWinDde also evals in this fashion.
    */
    savedState = Tcl_SaveInterpState(interp, TCL_OK);
    Tcl_ResetResult (interp);
    hr = Tcl_EvalObjv(interp, cmdobjc, cmdobjv, TCL_EVAL_GLOBAL);
    if (hr != TCL_OK) {
        hr = TwapiMapErrorToHRESULT(hr);
        ObjToBSTR(interp, ObjGetResult(interp), &errorBstr);
        Tcl_BackgroundError(interp);
    } else {
        /* TBD - check if interp deleted ? */

        /* TBD - appropriately init retvarP from ObjGetResult keeping
         * in mind that the retvarP by be BYREF as well.
         */
        if (retvarP) {
            VARTYPE ret_vt;
            Tcl_Obj *retObj = ObjGetResult(interp);

            VariantInit(retvarP); /* TBD - should be VariantClear ? */
            ret_vt = ObjTypeToVT(retObj);

            if (ObjToVARIANT(interp, retObj, retvarP, ret_vt) != TCL_OK) {
                hr = E_FAIL;
                goto restore_and_return;
            }
            if (retvarP->vt == VT_DISPATCH || retvarP->vt == VT_UNKNOWN) {
                /* When handing out interfaces, must increment their refs */
                if (retvarP->punkVal != NULL)
                    retvarP->punkVal->lpVtbl->AddRef(retvarP->punkVal);
            }
        }
        hr = S_OK;
    }

restore_and_return:
    Tcl_RestoreInterpState(interp, savedState);

vamoose:
    if (FAILED(hr)) {
        if (excepP) {
            TwapiZeroMemory(excepP, sizeof(*excepP));
            excepP->scode = hr;
            if (errorBstr) {
                /* Ownership of the BSTR passes to the caller */
                excepP->bstrDescription = errorBstr;
                errorBstr = NULL;
            }
            hr = DISP_E_EXCEPTION;
        }
        if (argErrP)
            *argErrP = argErr;
    }

    if (errorBstr) {
        SysFreeString(errorBstr);
        errorBstr = NULL;
    }
    if (cmdobjv) {
        for (i = 0; i < cmdobjc; ++i) {
            ObjDecrRefs(cmdobjv[i]);
        }
    }

    MemLifoPopFrame(me->ticP->memlifoP);

    /* Undo the AddRef we did before */
    this->lpVtbl->Release(this);
    /* this/me may be invalid at this point! Make sure we don't access them */

    return hr;
}

/*
 * Tcl command that creates an automation object and returns its IDispatch
 * interface.
 *
 * Arguments: IID MEMIDMAP CMD
 *   IID      - interface id string, parsed with IIDFromString
 *   MEMIDMAP - flat list of alternating {memid name} pairs; every memid
 *              must be an integer
 *   CMD      - command prefix that Twapi_ComServer_Invoke evaluates with the
 *              member name and arguments appended
 *
 * The new object starts with a reference count of 1 and also holds a
 * server-process reference (CoAddRefServerProcess) until it is released.
 */
int Twapi_ComServerObjCmd(
    ClientData clientdata,
    Tcl_Interp *interp,
    int objc,
    Tcl_Obj *CONST objv[])
{
    TwapiInterpContext *ticP = (TwapiInterpContext*) clientdata;
    Twapi_ComServer *serverP;
    Tcl_Obj **mapv;
    Tcl_Size mapc, k;
    IID iid;
    HRESULT hr;

    TWAPI_ASSERT(ticP->interp == interp);

    if (objc != 4) {
        Tcl_WrongNumArgs(interp, 1, objv, "IID MEMIDMAP CMD");
        return TCL_ERROR;
    }

    hr = IIDFromString(ObjToWinChars(objv[1]), &iid);
    if (FAILED(hr))
        return Twapi_AppendSystemError(interp, hr);

    if (ObjGetElements(interp, objv[2], &mapc, &mapv) != TCL_OK)
        return TCL_ERROR;
    if ((mapc & 1) != 0) {
        /* The map must consist of whole {memid name} pairs */
        ObjSetStaticResult(interp, "Invalid memid map");
        return TCL_ERROR;
    }
    for (k = 0; k + 1 < mapc; k += 2) {
        int memid;              /* Parsed only to validate it */
        if (ObjToInt(interp, mapv[k], &memid) != TCL_OK)
            return TCL_ERROR;
    }

    /* Freed by Twapi_ComServer_Release when the last reference goes away */
    serverP = TwapiAlloc(sizeof(*serverP));
    serverP->idispP.lpVtbl = &Twapi_ComServer_Vtbl;
    serverP->iid = iid;
    serverP->refc = 1;

    ObjIncrRefs(objv[2]);
    serverP->memids = objv[2];
    ObjIncrRefs(objv[3]);
    serverP->cmd = objv[3];

    TwapiInterpContextRef(ticP, 1);
    serverP->ticP = ticP;

    /* Balanced by CoReleaseServerProcess in Twapi_ComServer_Release */
    CoAddRefServerProcess();

    ObjSetResult(interp, ObjFromIUnknown(serverP));
    return TCL_OK;
}

/*
 * Class factory implementation
 */

/*
 * IUnknown::QueryInterface for Twapi_ClassFactory.
 *
 * Succeeds (returning this factory, AddRef'ed) for IID_IClassFactory and
 * IID_IUnknown. Returns E_POINTER if ifcPP is NULL, and E_NOINTERFACE (with
 * *ifcPP set to NULL) for any other IID.
 */
static HRESULT STDMETHODCALLTYPE Twapi_ClassFactory_QueryInterface(
    IClassFactory *this,
    REFIID riid,
    void **ifcPP)
{
    /* A NULL out-pointer must be rejected rather than dereferenced */
    if (ifcPP == NULL)
        return E_POINTER;

    if (!IsEqualIID(riid, &IID_IClassFactory) &&
        !IsEqualIID(riid, &IID_IUnknown)) {
        /* Not a supported interface */
        *ifcPP = NULL;
        return E_NOINTERFACE;
    }

    this->lpVtbl->AddRef(this);
    *ifcPP = this;
    return S_OK;
}


/*
 * IUnknown::AddRef - bumps the factory's reference count and returns the
 * new value.
 * TBD - the count is updated with plain, non-atomic arithmetic.
 */
static ULONG STDMETHODCALLTYPE Twapi_ClassFactory_AddRef(IClassFactory *this)
{
    Twapi_ClassFactory *self = (Twapi_ClassFactory *) this;

    return ++self->refc;
}

/*
 * IUnknown::Release - drops one reference and returns the remaining count.
 *
 * At zero, the held Tcl objects and the interpreter context reference are
 * released and the factory is freed. Unlike Twapi_ComServer_Release, no
 * server-process reference is released here, since the factory's creator
 * does not take one.
 * TBD - the count is updated with plain, non-atomic arithmetic.
 */
static ULONG STDMETHODCALLTYPE Twapi_ClassFactory_Release(IClassFactory *this)
{
    Twapi_ClassFactory *self = (Twapi_ClassFactory *) this;

    if (--self->refc != 0)
        return self->refc;

    if (self->memids)
        ObjDecrRefs(self->memids);
    if (self->cmd)
        ObjDecrRefs(self->cmd);
    if (self->ticP)
        TwapiInterpContextUnref(self->ticP, 1);

    TwapiFree(self);
    return 0;
}

/*
 * IClassFactory::CreateInstance - delegates object creation to the Tcl
 * callback.
 *
 * The command prefix me->cmd, with the requested IID appended, is
 * evaluated at global level. Its result must be an interface pointer
 * (parsed with ObjToIUnknown), which is returned in *ppv. The interface is
 * NOT AddRef'ed here; the creation script is expected to hand over a
 * reference.
 *
 * Returns:
 *   E_POINTER             - ppv is NULL, or the interp is missing/deleted
 *   CLASS_E_NOAGGREGATION - aggregation was requested
 *   E_FAIL                - the callback failed or returned a non-interface
 *                           (script errors are also reported via
 *                           Tcl_BackgroundError)
 *
 * Must be called on the interpreter's thread (panics otherwise).
 */
static HRESULT STDMETHODCALLTYPE Twapi_ClassFactory_CreateInstance (
    IClassFactory* this,
    LPUNKNOWN pUnkOuter,
    REFIID riid,
    LPVOID* ppv)
{
    Twapi_ClassFactory *me = (Twapi_ClassFactory *)this;
    HRESULT hr;
    Tcl_Obj **cmdobjv = NULL;
    Tcl_Obj **cmdprefixv;
    Tcl_Size  i, cmdobjc;
    Tcl_InterpState savedState;
    Tcl_Interp *interp;

    /* Reject a NULL out-pointer instead of dereferencing it */
    if (ppv == NULL)
        return E_POINTER;
    *ppv = NULL;

    if (pUnkOuter)
        return CLASS_E_NOAGGREGATION;

    if (me == NULL || me->ticP == NULL || me->ticP->interp == NULL)
        return E_POINTER;

    if (me->ticP->thread != Tcl_GetCurrentThread())
        Tcl_Panic("Twapi_ClassFactory_CreateInstance called from non-interpreter thread");

    interp = me->ticP->interp;
    if (Tcl_InterpDeleted(interp))
        return E_POINTER;

    if (ObjGetElements(NULL, me->cmd, &cmdobjc, &cmdprefixv) != TCL_OK) {
        /* Internal error - were invalid command. Should we log background error?*/
        return E_FAIL;
    }

    /*
     * Before eval'ing, addref ourselves so we don't get deleted in a
     * recursive callback
     */
    this->lpVtbl->AddRef(this);


    /* Note we will tack on IID */
    cmdobjv = MemLifoPushFrame(me->ticP->memlifoP, (cmdobjc+1) * sizeof(*cmdobjv), NULL);

    for (i = 0; i < cmdobjc; ++i) {
        cmdobjv[i] = cmdprefixv[i];
        ObjIncrRefs(cmdobjv[i]);
    }

    cmdobjv[cmdobjc] = ObjFromGUID(riid);
    ObjIncrRefs(cmdobjv[cmdobjc]);
    cmdobjc += 1;


    /* TBD - is this safe as we are being called from the message dispatch
       loop? Or should we queue to pending callback queue ? But in that
       case we cannot get results back as we can't block in this thread
       as the script invocation will also be in this thread. Also, is
       the Tcl_SaveInterpState/RestoreInterpState really necessary ?
       Note tclWinDde also evals in this fashion.
    */
    savedState = Tcl_SaveInterpState(interp, TCL_OK);
    Tcl_ResetResult (interp);
    hr = Tcl_EvalObjv(interp, cmdobjc, cmdobjv, TCL_EVAL_GLOBAL);
    if (hr != TCL_OK) {
        hr = TwapiMapErrorToHRESULT(hr);
        Tcl_BackgroundError(interp);
    } else {
        /* TBD - check if interp deleted ? */
        void *pv;
        if (ObjToIUnknown(interp, ObjGetResult(interp), &pv) == TCL_OK) {
            /* We are not AddRef'ing the interface because that should
               have been done by the creation script.
            */
            *ppv = pv;
            hr = S_OK;
        } else {
            hr = E_FAIL;
        }
    }

    Tcl_RestoreInterpState(interp, savedState);

    if (cmdobjv) {
        for (i = 0; i < cmdobjc; ++i) {
            ObjDecrRefs(cmdobjv[i]);
        }
    }

    MemLifoPopFrame(me->ticP->memlifoP);

    /* Undo the AddRef we did before */
    this->lpVtbl->Release(this);
    /* this/me may be invalid at this point! Make sure we don't access them */

    return hr;
}

/*
 * IClassFactory::LockServer - adds or removes a server-process reference.
 *
 * When unlocking drops the last server-process reference, the script is
 * signalled via TwapiComServerShutdown (if the factory still has an
 * interpreter context). Always returns S_OK.
 */
static HRESULT STDMETHODCALLTYPE Twapi_ClassFactory_LockServer (
    IClassFactory * this,
    BOOL lock)
{
    Twapi_ClassFactory *factoryP = (Twapi_ClassFactory *)this;

    if (lock) {
        CoAddRefServerProcess();
        return S_OK;
    }

    /* If no more objects (of any type) we can shutdown the process */
    if (CoReleaseServerProcess() == 0 && factoryP->ticP != NULL)
        TwapiComServerShutdown(factoryP->ticP->interp);

    return S_OK;
}

/*
 * Tcl command that creates a class factory and returns its IClassFactory
 * interface.
 *
 * Arguments: CLSID CMD
 *   CLSID - class id string, parsed with CLSIDFromString
 *   CMD   - command prefix that Twapi_ClassFactory_CreateInstance evaluates
 *           with the requested IID appended
 *
 * The new factory starts with a reference count of 1.
 */
int Twapi_ClassFactoryObjCmd(
    ClientData clientdata,
    Tcl_Interp *interp,
    int objc,
    Tcl_Obj *CONST objv[])
{
    TwapiInterpContext *ticP = (TwapiInterpContext*) clientdata;
    Twapi_ClassFactory *cfP;
    CLSID clsid;
    HRESULT hr;

    TWAPI_ASSERT(ticP->interp == interp);

    if (objc != 3) {
        Tcl_WrongNumArgs(interp, 1, objv, "CLSID CMD");
        return TCL_ERROR;
    }

    hr = CLSIDFromString(ObjToWinChars(objv[1]), &clsid);
    if (FAILED(hr))
        return Twapi_AppendSystemError(interp, hr);

    /* Memory is freed when the object is released */
    cfP = TwapiAlloc(sizeof(*cfP));

    /* Fill in the cmdargs slots from the arguments */
    cfP->ifacP.lpVtbl = &Twapi_ClassFactory_Vtbl;
    cfP->clsid = clsid;
    /*
     * No memid map is used by factories. The field must still be set
     * explicitly: Twapi_ClassFactory_Release calls ObjDecrRefs on it when
     * non-NULL, and TwapiAlloc memory is not assumed to be zeroed.
     */
    cfP->memids = NULL;
    TwapiInterpContextRef(ticP, 1);
    cfP->ticP = ticP;
    cfP->refc = 1;
    ObjIncrRefs(objv[2]);
    cfP->cmd = objv[2];

    ObjSetResult(interp, ObjFromIClassFactory(cfP));

    return TCL_OK;
}
