/* ---------------------------------------------------------------------------
 * This software is in the public domain, furnished "as is", without technical
 * support, and with no warranty, express or implied, as to its usefulness for
 * any purpose.

 * Author: Wil Braithwaite.
 *
 */

#include "CudaScheduler.h"
#include <cuda_runtime_api.h>
#include <cuda_gl_interop.h>
#include "cuda_utils.h"

#if defined( _WIN32 )

#else
#include <GL/glx.h>
#include <X11/Xlib.h>
#endif

namespace Easy
{
//---------------------------------------------------------------------------------------
// Process-wide lock. Within this file it is only used to serialise GL context
// creation in the worker threads (CudaSchedulerTask::onExecute).
Mutex CudaScheduler::lock;
// Diagnostic verbosity. > 0 logs task creation, task start/exit and job execution;
// > 1 also logs job completion; > 2 also logs every wait on the job semaphore.
// 0 disables these messages.
int CudaScheduler::debugging = 0;

//---------------------------------------------------------------------------------------
// Construct a worker task owned by scheduler `s`. No thread work happens here.
//   s             - owning scheduler; onExecute() pulls jobs from its queue.
//   i             - task index; jobs whose _deviceIndex equals this value are
//                   pinned to this task (see onExecute()).
//   glContextData - GL context description copied into the task; the worker
//                   thread makes it current in onExecute().
//   deviceIndex   - CUDA device ordinal selected by onExecute().
// NOTE(review): presumably execute() (called by the scheduler constructor) runs
// onExecute() on a new thread — TODO confirm against the base class.
CudaSchedulerTask::CudaSchedulerTask(CudaScheduler *s, int i, GlContextData glContextData, int deviceIndex)
        :
        _scheduler(s),
        _index(i),
        _deviceIndex(deviceIndex),
        _exitFlag(false),
		_glContextData(glContextData)
{
}

//---------------------------------------------------------------------------------------
#if defined(_WIN32)

// Create a new WGL rendering context for data->hdc.
// If data->hglrc already holds a context, display lists are shared with it via
// wglShareLists. On success data->hglrc is replaced by the new context and true
// is returned. On failure the partially created context is deleted, data is left
// unchanged, and false is returned.
// NOTE: unlike the non-Windows implementation, GLEW is not initialised here.
bool glCreateContext(GlContextData* data)
{
	HGLRC hglrc = wglCreateContext (data->hdc);
	if (hglrc == 0)
	{
		DWORD err = GetLastError();
		std::cerr << Stringf("Failed on wglCreateContext: ") << err << std::endl;
        return false;
	}

	if (data->hglrc && wglShareLists(data->hglrc, hglrc) == 0)
	{
		// Read the error code before any other WGL call can overwrite it.
		DWORD err = GetLastError();
		std::cerr << Stringf("Failed on wglShareLists: ") << err << std::endl;
		// The new context is unusable for this caller; release it instead of leaking it.
		wglDeleteContext(hglrc);
        return false;
	}

    data->hglrc = hglrc;
	std::cout << Stringf("Created GL context(%p) for display(%p) ", data->hglrc, data->hdc) << std::endl;
	return true;
}

//---------------------------------------------------------------------------------------
// Delete the WGL context held in data, then clear both handles so that a
// repeated call is a no-op. Does nothing when no device context is attached.
void glDestroyContext(GlContextData* data)
{
	if (data->hdc)
	{
		assert(data->hglrc);
		std::cout << Stringf("Destroying GL context(%p) for display(%p)", data->hglrc, data->hdc) << std::endl;
		wglDeleteContext(data->hglrc);

		data->hglrc = 0;
		data->hdc = 0;
	}
}

//---------------------------------------------------------------------------------------
// Make data's WGL context current on the calling thread.
// Does nothing when no device context is attached.
void glStartContext(GlContextData* data)
{
	HDC hdc = data->hdc;
	if (hdc)
	{
		assert(data->hglrc);
		wglMakeCurrent(hdc, data->hglrc);
	}
}

//---------------------------------------------------------------------------------------
// Release the calling thread's current WGL context.
// Does nothing when no device context is attached.
void glEndContext(GlContextData* data)
{
	HDC hdc = data->hdc;
	if (hdc)
		wglMakeCurrent(hdc, 0);
}

//---------------------------------------------------------------------------------------
#else

// Create a GLX context on data->xdisplay's default screen and make it current on
// the root window.
// The new context shares with data->glxcontext (if non-null) and requests direct
// rendering. GLEW is then initialised for it. On success, data->glxcontext and
// data->xwindow receive the new context and root window, and true is returned.
// On any failure, everything created here is released, data is left unchanged,
// and false is returned.
bool glCreateContext(GlContextData* data)
{
    Display* xdisplay = data->xdisplay;
    char* displayName = DisplayString(xdisplay);
    int xscreen = DefaultScreen(xdisplay);
    Window xwindow = RootWindow(xdisplay, xscreen);

    GLint attributes[] = {GLX_RGBA, GLX_DEPTH_SIZE, 24, GLX_DOUBLEBUFFER, None};
    XVisualInfo *xvisual = glXChooseVisual(xdisplay, xscreen, attributes);
    if(xvisual == NULL)
    {
        std::cerr << Stringf("Failed to choose GLX-visual") << std::endl;
        return false;
    }
    XSync(xdisplay,false);
    GLXContext glxcontext = glXCreateContext(xdisplay, xvisual, data->glxcontext, true);
    // The visual is only needed to create the context; glXChooseVisual's result
    // must be released with XFree on every path.
    XFree(xvisual);
    if(glxcontext == NULL)
    {
        std::cerr << Stringf("Failed to initialize GL-context for ") << displayName << std::endl;
        return false;
    }
    XSync(xdisplay,false);
    if(!glXMakeCurrent(xdisplay, xwindow, glxcontext))
    {
        std::cerr << Stringf("Failed to make GL-context current for ") << displayName << std::endl;
        glXDestroyContext(xdisplay, glxcontext);
        return false;
    }
    XSync(xdisplay,false);
    std::cout << Stringf("Created GL-context for ") << displayName << std::endl;
    if(glewInit() != GLEW_OK)
    {
        std::cerr << Stringf("Failed to initialize GLEW") << std::endl;
        // Unbind before destroying so the thread is not left with a dangling context.
        glXMakeCurrent(xdisplay, None, NULL);
        glXDestroyContext(xdisplay, glxcontext);
        return false;
    }

	data->glxcontext = glxcontext;
	data->xwindow = xwindow;

	std::cout << Stringf("Creating GL context(%p) for display(%p) ", glxcontext, xdisplay) << std::endl;
	return true;
}

//---------------------------------------------------------------------------------------
// Destroy the GLX context held in data.
// Afterwards all handles are cleared, as the Windows implementation does, so a
// repeated call or a later glStartContext is a no-op rather than using a destroyed
// context. Does nothing when no display is attached.
// NOTE: only the context is destroyed; the X display connection is not closed,
// despite the wording of the log message.
void glDestroyContext(GlContextData* data)
{
	if (!data->xdisplay)
		return;

    assert(data->glxcontext);
	// Guard in release builds too: destroying a null context raises an X error.
	if (data->glxcontext)
	{
		std::cout << Stringf("Destroying GLX display(%p) and context(%p)", data->xdisplay, data->glxcontext) << std::endl;
		glXDestroyContext( data->xdisplay, data->glxcontext);
	}

	data->xdisplay = 0;
	data->glxcontext = 0;
	data->xwindow = 0;
}

//---------------------------------------------------------------------------------------
// Bind data's GLX context to its window on the calling thread.
// Does nothing when no display is attached.
void glStartContext(GlContextData* data)
{
	Display* xdisplay = data->xdisplay;
	if (xdisplay == NULL)
		return;

	assert(data->glxcontext);
	assert(data->xwindow);
	glXMakeCurrent(xdisplay, data->xwindow, data->glxcontext);
}

// Release the calling thread's current GLX context.
// Does nothing when no display is attached.
void glEndContext(GlContextData* data)
{
	if (data->xdisplay != NULL)
		glXMakeCurrent(data->xdisplay, None, 0);
}

#endif

//---------------------------------------------------------------------------------------
void CudaSchedulerTask::onExecute()
{
	if (CudaScheduler::debugging > 0)
		std::cout << Stringf("task[%d] GPU[%d] is ready...", _index, _deviceIndex) << std::endl;

    NVPARTICLES_CUDA_SAFE_CALL(cudaSetDeviceFlags(cudaDeviceMapHost));

#ifndef _WIN32

    if (_scheduler->useGL)
    {
#if defined(_WIN32)
		_glContextData.hdc = _scheduler->_hdc;
		_glContextData.hglrc = _scheduler->_hglrc;
#else
		_glContextData.xdisplay = _scheduler->_display;
		_glContextData.glxcontext = _scheduler->_glcontext;
		_glContextData.xwindow = 0;
#endif

		CudaScheduler::lock.claim();
		glCreateContext(&_glContextData);
		CudaScheduler::lock.release();
    }
#endif

    /// CAVEAT:
    // this is another thread.
    // if context access happens at the same time as a GL call on the host, it can cause hanging!
	glStartContext(&_glContextData);

    if(_scheduler->useGL)
    {
		NVPARTICLES_CUDA_SAFE_CALL(cudaGLSetGLDevice(_deviceIndex));
	}
	else
    {
        NVPARTICLES_CUDA_SAFE_CALL(cudaSetDevice(_deviceIndex));
    }


    _readyLock.release();

    for (;;)
    {
        if(CudaScheduler::debugging > 2)
            std::cout << Stringf("task[%d] GPU[%d] is waiting...", _index, _deviceIndex) << std::endl;

        // Wait until a job exists
        _scheduler->_itemSemaphore.wait();

        if (_exitFlag)
            break;

        // get the next job that can run
        _scheduler->_listLock.claim();

        // get next job index...
        int nextJobIndex = -1;
        CudaSchedulerJob* job = NULL;
        for(int i=0; i<(int)_scheduler->_job.size(); ++i)
        {
            job = _scheduler->_job[i];
            if(job->_deviceIndex == -1 || job->_deviceIndex == _index)
            {
                nextJobIndex = i;
                break;
            }
        }

        if(nextJobIndex >= 0)
        {
            job->_deviceIndex = _deviceIndex;

            // remove the job pointer (but don't delete the job data)
            _scheduler->_job.erase(_scheduler->_job.begin()+nextJobIndex);

            ++_scheduler->_nBusy;
        }

        _scheduler->_listLock.release();

        if (nextJobIndex == -1)
        {
            // should we put the job back?
            /// not the most efficient (I should probably have three job queues (device1, device2, and deviceany) ????
            _scheduler->_itemSemaphore.signal();
            continue;
        }

        if(CudaScheduler::debugging > 0)
            std::cout << Stringf("task[%d] GPU[%d] is executing job(%p)...", _index, _deviceIndex, job) << std::endl;


        if(!job->onExecute(this))
        {
            std::cerr << Stringf("task[%d] GPU[%d] job(%p) error!", _index, _deviceIndex, job) << std::endl;
        }

        if(CudaScheduler::debugging > 1)
            std::cout << Stringf("task[%d] GPU[%d] completed job(%p)...", _index, _deviceIndex, job) << std::endl;

        // tell the scheduler that we've completed another job
        _scheduler->_listLock.claim();
        ++_scheduler->_nDone;

        // if we've completed all submitted jobs then tell scheduler
        if (--_scheduler->_nBusy == 0)
        {
            if (_scheduler->_nSubmitted==_scheduler->_nDone)
                _scheduler->_allDone.signal();
        }

        _scheduler->_listLock.release();

        if (job->_deleteAtExit)
            delete job;
    }

	if(CudaScheduler::debugging > 0)
		std::cout << Stringf("task[%d] GPU[%d] is exiting.", _index, _deviceIndex) << std::endl;

    // this clears resources for ALL threads in current process.
    // don't do this because we don't want to mess things up for other plug-ins!
    NVPARTICLES_CUDA_SAFE_CALL(cudaDeviceReset());

	glEndContext(&_glContextData);
	glDestroyContext(&_glContextData);
}

//---------------------------------------------------------------------------------------
#if defined(_WIN32)
CudaScheduler::CudaScheduler(int startDevice, int numDevices, HDC display, HGLRC glcontext)
#else
CudaScheduler::CudaScheduler(int startDevice, int numDevices, Display* display, GLXContext glcontext)
#endif
        :
        _itemSemaphore(0),
        _allDone(1),
        _startDevice(startDevice)
{
    // Create the worker tasks and wait for each one to finish initialising.
    //   startDevice - CUDA device ordinal used by the first task.
    //   numDevices  - 0: one task per CUDA device.
    //                 >0: that many tasks on ascending ordinals.
    //                 <0: |numDevices| tasks on descending ordinals.
    //                 Ordinals wrap modulo the device count.
    //   display, glcontext - host GL handles; a non-zero display enables GL interop.
	_display = display;
	_glcontext = glcontext;
	useGL = (_display != 0);

    int nTask=0;
    int gpuCount=0;

    NVPARTICLES_CUDA_SAFE_CALL(cudaGetDeviceCount(&gpuCount));

    // With no CUDA devices create no workers (add() then runs jobs synchronously);
    // otherwise the device arithmetic below would divide by zero or never terminate.
    if (gpuCount > 0)
        nTask = max(((numDevices==0)?gpuCount:((numDevices<0)?-numDevices:numDevices)),1);

	if(CudaScheduler::debugging > 0)
		std::cout << "Creating " << nTask << " tasks" << std::endl;

#if defined(_WIN32)
    if (useGL)
    {
        // Release the host context while the per-task contexts are created;
        // it is restored at the end of the constructor.
        wglMakeCurrent(NULL, NULL);
    }
#endif

    _nSubmitted = _nDone = 0;
    _nBusy=0;

    for (int i=0; i<nTask; ++i)
    {
        CudaSchedulerTask *st;
        int device = _startDevice;
        if (numDevices < 0)
            device -= i;
        else
            device += i;

        while (device < 0)
            device += gpuCount;

		// Describe the GL context this task will use. The fields are always
		// initialised: when useGL is false the display handle is 0, which makes
		// the gl*Context helpers no-ops in the worker thread.
		GlContextData glContextData = GlContextData();
#if defined(_WIN32)
		glContextData.hdc = _display;
		glContextData.hglrc = _glcontext;
#else
		glContextData.xdisplay = _display;
		glContextData.glxcontext = _glcontext;
		glContextData.xwindow = 0;
#endif

#if defined(_WIN32)
		// On Windows the per-task context is created here on the calling thread
		// (sharing with _glcontext); elsewhere the worker thread creates it.
		if (useGL)
		{
			if (!glCreateContext(&glContextData))
			{
				std::cerr << Stringf("Unable to create openGL context") << std::endl;
				abort();
			}
		}
#endif

        task.push_back(st=new CudaSchedulerTask(this, i, glContextData, device%gpuCount));

        st->_readyLock.claim();

        st->execute();

        // wait for worker thread to finish its initialization.
        st->_readyLock.claim();
    }

#if defined(_WIN32)
    if (useGL)
    {
		// restore the gl context...
        wglMakeCurrent(_display, _glcontext);
    }
#endif
}

//---------------------------------------------------------------------------------------
CudaScheduler::~CudaScheduler()
{
    // Without worker tasks there is nothing to tear down.
    if (task.size() == 0)
        return;

    // Block until every submitted job has been reported done.
    _allDone.wait();

    // Ask all workers to leave their job loops.
    endTasks();

    const unsigned int nTasks = (unsigned int)task.size();

    // Wait for every worker to finish gracefully before freeing any of them.
    for (unsigned int t = 0; t < nTasks; ++t)
        task[t]->wait();

    for (unsigned int t = 0; t < nTasks; ++t)
        delete task[t];

    task.clear();
}

//---------------------------------------------------------------------------------------
void CudaScheduler::endTasks()
{
    for (unsigned int i=0;i<task.size();++i)
        task[i]->_exitFlag = true;

    // tell them to stop waiting...
    for (unsigned int i=0;i<task.size();++i)
        _itemSemaphore.signal();
}

//---------------------------------------------------------------------------------------
void CudaScheduler::waitAllDone()
{
    // Without worker tasks, add() runs jobs synchronously, so nothing is pending.
    if (task.size() == 0)
        return;

    // Take the "all done" token, then hand it straight back so that later
    // waiters (including the destructor) also observe the idle state.
    _allDone.wait();
    _allDone.signal();
}

//---------------------------------------------------------------------------------------
// Submit job j.
//   index - task to pin the job to (taken modulo the task count); a negative index
//           leaves the job unpinned (-1), so any worker may run it.
//   owner - when true the job is deleted after it has executed.
// Without worker tasks the job runs synchronously on the calling thread.
void CudaScheduler::add(CudaSchedulerJob *j, int index, bool owner)
{
    if (task.size()>0)
    {
        _listLock.claim();

        ++_nSubmitted;

        // Consume the "all done" token so waitAllDone() blocks until jobs finish.
        while (_allDone.value()>0)
        {
            _allDone.wait();
        }

        // Workers treat _deviceIndex == -1 as "any task". Compute the modulo in
        // signed arithmetic: with an unsigned task.size(), -1 % n never gives -1 and
        // would pin the job to an arbitrary task.
        const int nTasks = (int)task.size();
        j->_deviceIndex = (index < 0) ? -1 : (index % nTasks);
        j->_deleteAtExit = owner;
        _job.push_back(j);

        _listLock.release();

        _itemSemaphore.signal();
    }
    else
    {
        j->onExecute(NULL);
        // Honour the ownership flag, as the worker threads do via _deleteAtExit;
        // a caller that kept ownership must not have its job freed here.
        if (owner)
            delete j;
    }
}

//---------------------------------------------------------------------------------------
}
