/*
 * Copyright 1993-2012 NVIDIA Corporation.  All rights reserved.
 *
 * Please refer to the NVIDIA end user license agreement (EULA) associated
 * with this source code for terms and conditions that govern your use of
 * this software. Any use, reproduction, disclosure, or distribution of
 * this software and related documentation outside the terms of the EULA
 * is strictly prohibited.
 *
 */

#include "GlutApp.h"
#include "gl_utils.h"
#include "math_utils.h"
#include "std_utils.h"
#include "cuda_utils.h"

#include "NvParticlesManager.h"
#include "NvParticlesProfiler.h"
#include "NvParticlesGrid.h"
#include "NvParticlesPrimitives.h"
#include "NvParticlesForces.h"
#include "NvParticlesParticleSolverImpl.h"
#include "NvParticlesParticleRenderer.h"
#include "NvParticlesParticleContainer.h"

#include "../../solvers/wcsph/Wcsph.h"

#ifdef _WIN32
#include <conio.h>
#endif

// Size factor for the animated capsule primitive built in OnUpdate(); it is only used as TEST_PRIMITIVE_SCALE/10.
#define TEST_PRIMITIVE_SCALE (100)

using namespace Easy;
using namespace Easy::NvParticles;

// Display toggles, flipped from the keyboard handler and forwarded to the solver/renderer.
bool drawGrid(false);         // grid flag passed to ParticleSolver::render() and the "drawGrid" parameter
bool renderLabels(false);     // not referenced elsewhere in this file
bool drawPrimitives(true);    // primitives flag passed to ParticleSolver::render()
bool drawBounds(true);        // bounds flag passed to ParticleSolver::render()
std::string renderMethod("spheres");  // ParticleRenderer type name: "spheres" or "points"

/// One-shot request to run a single simulation step even while paused:
/// DoUpdate() sets it, OnUpdate() performs the step and clears it.
bool firstTime(false);

/// Interactive GLUT test harness for the NvParticles "wcsph" solver.
///
/// Owns a ParticleContainer (particle buffers), a ParticleRenderer and a
/// ParticleSolver, all created in OnInit() and destroyed in OnExit().
/// Test scenes 1-4 are configured by _setTest(); a scene switch is requested
/// by setting newTestMode (keys '1'..'4' or the -test option) and is applied
/// on the next update.
class CudaParticlesApp : public GlutApp
{
    typedef GlutApp inherited;

    float renderRadiusFactor;   // sent to the renderer as "renderer_radiusFactor"
    float colorScale;           // sent to the solver as "colorScale"
    int colorStyle;             // sent to the solver as "colorStyle"

    bool initialized;           // set at the end of OnInit(); OnUpdate() does nothing until then
    // Holds the solver factory that "wcsph" is registered with.
    // NOTE(review): never deleted by this class — confirm whether that is intended.
    NvParticles::Manager* nvParticles;

    ParticleContainer* particleContainer;
    ParticleRenderer* particleRenderer;
    ParticleSolver* particleSolver;

    // Solver parameters; a member, so values not set by a scene persist from earlier scenes.
    NvParticles::Parameters particleParameters;
    unsigned int currentFrame;
    float ballSpeed;            // animation speed of the moving capsule primitive

    int maxParticles;           // container capacity (-count / -maxParticles)
    int numBenchmarkIterations; // > 0: run this many update+display iterations from OnInit(), then exit (-iterations)
    bool waitForKey;            // Windows only: wait for a keypress in OnExit() (-wait)
    int cudaDeviceIndex;        // CUDA device handed to the solver (-device)

    int testMode, newTestMode;  // active scene; requested scene (0 = no request pending)

    float particleSpacing;      // queried from the solver after a scene is set up

public:

    virtual const char *Title()
    {
        return "NvParticles - test";
    }

    /// Builds the app, registers the "wcsph" solver type and parses the
    /// command line. Exits the process with status 1 on bad arguments and
    /// with status 0 after printing -help.
    CudaParticlesApp(int argc, char **argv, int mode=GLUT_RGB | GLUT_DEPTH | GLUT_DOUBLE | GLUT_MULTISAMPLE)
        :
        inherited(argc, argv, mode), initialized(false)
    {
        maxParticles = 300000;
        numBenchmarkIterations = 0;
        waitForKey = false;
        cudaDeviceIndex = 0;

        particleContainer = 0;
        particleRenderer = 0;
        particleSolver = 0;

        draw_profiler = false;

        nvParticles = new NvParticles::Manager;
        nvParticles->solverFactory.registerType("wcsph", Easy::NvParticles::Wcsph::Solver::creator);

        currentFrame = 0;
        ballSpeed = 0.03f;
        testMode = 0;
        newTestMode = 1;    // scene 1 unless -test overrides it
        particleSpacing = 0;
        renderRadiusFactor = 1;
        colorStyle = 0;
        colorScale = 1.0;

        // parseArgs: < 0 = bad arguments, 0 = help printed, > 0 = continue.
        int rc = parseArgs(argc, argv);
        if (rc < 0)
        {
            printf("Incorrect parameters.\n");
            exit(1);
        }
        else if (rc == 0)
        {
            exit(0);
        }
    }

    /// Applies a pending scene request: configures the solver parameters for
    /// scene `test` (1-4), re-emits the initial particles, resets the profiler
    /// and kicks off an asynchronous solver update.
    /// Does nothing when no request is pending (newTestMode == 0).
    /// Scene numbers outside 1-4 only reset the common state below.
    void _setTest(int test)
    {
        if (!newTestMode)
            return;

        testMode = test;
        newTestMode = 0;
        srand(0);

        if(test == 1)
        {
            // Domain (boundaryMode 1) completely filled with particles.
            particleSolver->setSolverType("wcsph");

            particleParameters["particleRadius"].setFloat(1);
            particleParameters["restDensity"].setFloat(1000);
            particleParameters["densityThreshold"].setFloat(0);
            particleParameters["deltaTime"].setFloat(0.02);
            particleParameters["deltaTimeUseCfl"].setBool(false);
            particleParameters["speedOfSound"].setFloat(80);
            particleParameters["smoothingFactor"].setFloat(0.6);

            particleParameters["internalScale"].setFloat(1.f);
            particleParameters["artificialViscosity"].setFloat(0.5);
            particleParameters["viscosity"].setFloat(3.5);

            particleParameters["surfaceTension"].setFloat(0.0);
            particleParameters["surfaceTensionThreshold"].setFloat(7);

            particleParameters["xsph"].setFloat(0.5);
            particleParameters["gravity"].setFloat(9.8);

            particleParameters["boundaryStiffness"].setFloat(1000);
            particleParameters["boundaryDamping"].setFloat(0);
            mat44f xform = mat44f::scale(100,100,100) * mat44f::translate(0, 1, 0);
            particleParameters["boundaryMatrix"] = xform;
            particleParameters["boundaryMode"].setInt(1);

            particleParameters["surfaceDistance"].setFloat(0.3);
            particleParameters["negativePressureFactor"].setFloat(0);

            particleParameters["frameRate"].setFloat(1/0.02);

            particleParameters["densityInterpolationIterations"].setFloat(200);

            particleParameters["debugLevel"].setInt(0);

            particleParameters["glTexDisplacementSize"].setInt(0);

            particleContainer->reset();
            particleContainer->emitBox(mat44f::scale(1,1,1) * xform, 2, 0.0f, make_vec4f(1.f), 0);
        }
        else if(test == 2)
        {
            // Shallow layer filling 0.1 of the domain height (boundaryMode 2);
            // OnUpdate() keeps emitting to top the layer up.
            particleSolver->setSolverType("wcsph");

            particleParameters["particleRadius"].setFloat(1);
            particleParameters["restDensity"].setFloat(1000);
            particleParameters["densityThreshold"].setFloat(0);
            particleParameters["artificialViscosity"].setFloat(0.3);
            particleParameters["viscosity"].setFloat(3.5);

            particleParameters["deltaTime"].setFloat(0.02);
            particleParameters["deltaTimeUseCfl"].setBool(false);

            particleParameters["surfaceDistance"].setFloat(0.2);
            particleParameters["negativePressureFactor"].setFloat(0);

            particleParameters["speedOfSound"].setFloat(50);
            particleParameters["smoothingFactor"].setFloat(0.75);
            particleParameters["internalScale"].setFloat(1);
            particleParameters["surfaceTension"].setFloat(0.0);
            particleParameters["surfaceTensionThreshold"].setFloat(7);
            particleParameters["xsph"].setFloat(0.5);
            particleParameters["boundaryStiffness"].setFloat(1000);
            particleParameters["gravity"].setFloat(9.8);

            particleParameters["boundaryDamping"].setFloat(1);
            mat44f xform = mat44f::scale(100,100,50) * mat44f::translate(0, 1, 0);
            particleParameters["boundaryMatrix"] = xform;
            particleParameters["boundaryMode"].setInt(2);

            particleParameters["frameRate"].setFloat(1/0.02);

            particleParameters["densityInterpolationIterations"].setFloat(200);

            particleContainer->reset();
            particleContainer->emitBox(mat44f::scale(1,0.1,1) * xform, 2, 0.0f, make_vec4f(1.f), 0);

            renderRadiusFactor = 0.85;
        }
        else if(test == 3)
        {
            // Small block of particles (internalScale 40) with surface tension.
            particleSolver->setSolverType("wcsph");

            particleParameters["particleRadius"].setFloat(1);

            particleParameters["restDensity"].setFloat(600);
            particleParameters["densityThreshold"].setFloat(0);
            particleParameters["artificialViscosity"].setFloat(0.5);
            particleParameters["viscosity"].setFloat(3.5);
            particleParameters["deltaTime"].setFloat(0.004);
            particleParameters["speedOfSound"].setFloat(2);
            //particleParameters["smoothingFactor"].setFloat(0.85);
            particleParameters["smoothingFactor"].setFloat(1);

            particleParameters["internalScale"].setFloat(40.f);
            particleParameters["surfaceTension"].setFloat(0.01);
            particleParameters["surfaceTensionThreshold"].setFloat(7);
            particleParameters["xsph"].setFloat(0.5);
            particleParameters["boundaryStiffness"].setFloat(1000);
            particleParameters["gravity"].setFloat(9.8);

            particleParameters["boundaryDamping"].setFloat(0);
            mat44f xform = mat44f::scale(100,100,100) * mat44f::translate(0, 1, 0);
            particleParameters["boundaryMatrix"] = xform;
            particleParameters["boundaryMode"].setInt(2);

            particleParameters["frameRate"].setFloat(1/(0.004*1));

            particleContainer->reset();
            particleContainer->emitBox(mat44f::translate(0, 9, 0)*mat44f::scale(3,3,3), 1, 0.0f, make_vec4f(1.f), 0);
        }
        else if(test == 4)
        {
            // periodic water / moving boundary: OnUpdate() moves boundaryMatrix in a circle.
            particleSolver->setSolverType("wcsph");

            particleParameters["particleRadius"].setFloat(1);
            particleParameters["restDensity"].setFloat(600); // (kg / m^3)
            particleParameters["densityThreshold"].setFloat(0);
            particleParameters["artificialViscosity"].setFloat(0.5);
            particleParameters["viscosity"].setFloat(3.5);

            particleParameters["deltaTime"].setFloat(0.002);
            particleParameters["deltaTimeUseCfl"].setBool(false);
            particleParameters["cflFactor"].setFloat(0.3);

            particleParameters["speedOfSound"].setFloat(6);
            particleParameters["smoothingFactor"].setFloat(0.7);
            particleParameters["internalScale"].setFloat(40.f);
            particleParameters["surfaceTension"].setFloat(0);
            particleParameters["surfaceTensionThreshold"].setFloat(7);
            particleParameters["xsph"].setFloat(0.5);
            particleParameters["boundaryStiffness"].setFloat(1000);
            particleParameters["gravity"].setFloat(9.8);

            particleParameters["boundaryDamping"].setFloat(1);
            mat44f xform = mat44f::scale(100,100,100) * mat44f::translate(0, 1, 0);
            particleParameters["boundaryMatrix"] = xform;
            particleParameters["boundaryMode"].setInt(2); // wrap the X and Z bounds.

            particleParameters["surfaceDistance"].setFloat(0.6);
            particleParameters["frameRate"].setFloat(1/0.002);

            particleContainer->reset();
            particleContainer->emitBox(mat44f::scale(1,0.2,1) * xform, 2, 0.0f, make_vec4f(1.f), 0);
        }

        currentFrame = 2;

        particleParameters["lifespanMode"].setInt(0);
        particleParameters["lifespan"].setFloat(100);
        particleParameters["lifespanRandom"].setFloat(0);

        NvParticles::Profiler::getSingleton().Reset();

        particleSolver->setTime(0);
        particleSolver->setParameters(particleParameters);
        particleSolver->updateAsync();

        particleSpacing = particleSolver->getParticleSpacing();
        Repaint();
    }

    //------------------------------------------------------------------------------------------
    /// Creates the container (id/position/velocity/color/birthTime buffers),
    /// renderer and solver, sets up the initial scene and starts paused.
    /// With -iterations > 0, runs the benchmark loop here and exits the process.
    virtual void OnInit()
    {
        glEnable(GL_MULTISAMPLE);
        glClearColor(0.3f,0.3f,0.3f,0);

        SetPause(true);

        particleContainer = new ParticleContainer();
        particleContainer->setMaxParticles(maxParticles);
        particleContainer->addBuffer(ParticleBufferSpec("id", ParticleBufferSpec::UINT));
        particleContainer->addBuffer(ParticleBufferSpec("position", ParticleBufferSpec::FLOAT4));
        particleContainer->addBuffer(ParticleBufferSpec("velocity", ParticleBufferSpec::FLOAT4));
        particleContainer->addBuffer(ParticleBufferSpec("color", ParticleBufferSpec::FLOAT4));
        particleContainer->addBuffer(ParticleBufferSpec("birthTime", ParticleBufferSpec::FLOAT));

        particleRenderer = new ParticleRenderer();
        particleRenderer->setType("points");
        particleRenderer->resize(256, 256, 60.0f*PI/180.0f);

        particleSolver = new ParticleSolver();
        particleSolver->setCudaDeviceConfig(cudaDeviceIndex, gl::getDisplay(), gl::getContext());
        particleSolver->setContainer(particleContainer);
        //particleSolver->setExportedBuffers("position,velocity,color,density,force");

#if defined(_DEBUG)
        particleSolver->setDebugLevel(2);
#endif
        particleSolver->setSolverType("wcsph");
        particleSolver->setParameters(particleParameters);

        _setTest(newTestMode);

        particleSolver->setParameters(particleParameters);

        particleParameters["time"].setFloat(currentFrame/24.0f);

        initialized = true;

        if (numBenchmarkIterations > 0)
        {
            _reshape(width, height);

            for (int i=0; i<numBenchmarkIterations; ++i)
            {
                STDERR(i);
                DoUpdate();
                _display();
            }

            OnExit();
            exit(0);
        }
    }

    /// Forwards the new viewport size and field of view (degrees, converted
    /// to radians) to the renderer. Safe to call before OnInit().
    virtual void OnReshape(unsigned int w, unsigned int h, unsigned int fov)
    {
        if (!particleSolver || !particleRenderer)
            return;

        particleSolver->sync();
        // Use PI, matching the conversion in OnInit(); the old literal 3.14165 was a typo.
        particleRenderer->resize(w, h, static_cast<float>(fov) * PI / 180.0f);
    }

    /// Advances the simulation by one frame while unpaused. While paused, it
    /// advances once after DoUpdate() has set firstTime.
    /// Also applies any pending scene request.
    virtual void OnUpdate()
    {
        if (!initialized)
            return;

        NvParticles::Profiler::getSingleton().IncrementFrame();

        if (firstTime || !paused)
        {
            SetTitle(Easy::Stringf("frame=%d",currentFrame).c_str());

            firstTime = false;

            // presumably waits for the previous updateAsync() before solver state is modified.
            particleSolver->sync();

            _setTest(newTestMode);

            if (1)
            {
                // update an animated primitive...

                float ballFrame = currentFrame;
                mat44f newBallMat = mat44f::scale(TEST_PRIMITIVE_SCALE/10)
                    * mat44f::translate( 10*(powf(0.5*(sinf(ballFrame*ballSpeed/10)+1), 4)*2-1) * 0.3,
                                         10*(cosf(ballFrame*ballSpeed)+1)*0.25,
                                         10*(cosf(ballFrame*ballSpeed))*0.25)
                    * mat44f::rotateY(ballFrame*ballSpeed/2)
                    * mat44f::rotateX(ballFrame*ballSpeed);

                NvParticles::Primitive ball;
                memset(&ball, 0, sizeof(NvParticles::Primitive));
                ball.type = Primitive::PRIMITIVE_CAPSULE;
                ball.flags = Primitive::PRIMITIVE_FLAGS_EXTERIOR;
                ball.xform = newBallMat;
                ball.extents = make_vec3f(1, 1, TEST_PRIMITIVE_SCALE/10);

                particleSolver->addPrimitive("ball", ball);
            }

            if (testMode == 4)
            {
                // translate the domain in a circle around the origin...
                mat44f xform = mat44f::translate(sinf(((float)currentFrame/5)/PI)*30, 0, cosf(((float)currentFrame/5)/PI)*30) * mat44f::scale(100, 100, 100) * mat44f::translate(0, 1, 0);
                particleParameters["boundaryMatrix"] = xform;
            }

            if (testMode == 2)
            {
                // birth more particles to maintain the water level!
                mat44f xform = mat44f::scale(100, 100, 50) * mat44f::translate(0, 10, 0);
                particleContainer->emitBox(mat44f::scale(1,0.1,1) * xform, particleSpacing, 0.0f, make_vec4f(1.f), 0);
            }

            // update from gui parameters...
            particleParameters["colorScale"] = colorScale;
            particleParameters["drawBounds"] = drawBounds;
            particleParameters["drawPrimitives"] = drawPrimitives;
            particleParameters["colorStyle"] = colorStyle;

            _updateNvParticles();

            currentFrame += 1;
        }
    }

    /// Pushes the current time (frame / 24) and parameters to the solver and starts an async update.
    void _updateNvParticles()
    {
        particleSolver->setTime(currentFrame/24.0f);
        particleSolver->setParameters(particleParameters);
        particleSolver->updateAsync();
    }

    /// Renders the particles with the current GL matrices and render settings.
    /// Does nothing before OnInit() or after OnExit().
    void _renderNvParticles()
    {
        if (!particleRenderer || !particleSolver)
            return;

        particleRenderer->setType(renderMethod);

        // get the background textures...
        particleRenderer->readColorTexture();
        particleRenderer->readDepthTexture();

        Parameters renderParameters;

        mat44f mat;
        glGetFloatv(GL_MODELVIEW_MATRIX, (GLfloat*)&mat);
        renderParameters["modelViewMatrix"] = mat;
        glGetFloatv(GL_PROJECTION_MATRIX, (GLfloat*)&mat);
        renderParameters["projectionMatrix"] = mat;

        vec4f renderColor = renderParameters.asVector4("renderer_color", make_vec4f(0.0f, 0.8f, 1.0f,1));

        // sphere:
        renderParameters["renderer_useLighting"].setBool(true);
        renderParameters["renderer_absoluteRadius"].setFloat(0);
        renderParameters["renderer_radiusFactor"].setFloat(renderRadiusFactor);

        // point:
        renderParameters["renderer_streakLength"].setFloat(0); // increase this for streaks.
        renderParameters["renderer_streakWidth"].setFloat(2);
        renderParameters["renderer_useColor"].setBool(false);
        renderParameters["renderer_pointSize"].setInt(4);

        particleRenderer->updateParameters(renderParameters);

        glColor4f(renderColor.x, renderColor.y, renderColor.z, 1.0);

        // render will block until the solver thread is not doing a buffer update...
        particleSolver->render(particleRenderer, drawBounds, drawGrid, drawPrimitives);
    }

    /// Draws the ground grid, the particles, and the renderer's debug textures as a screen-space overlay.
    virtual void OnRender()
    {
        glCheckErrors();

        glPushAttrib(GL_ENABLE_BIT);

        glDisable(GL_TEXTURE_2D);
        glDisable(GL_TEXTURE_RECTANGLE);
        glDisable(GL_LIGHTING);
        glDisable(GL_BLEND);
        glDepthFunc(GL_LEQUAL);

        // draw a grid of radius 100...
        glColor4f(0.25,0.25,0.25,0.25);
        gl::drawWirePlane(100,1);
        glColor4f(0,0,0,1);
        gl::drawWirePlane(100,10);

        _renderNvParticles();

        if (particleRenderer)
        {
            // draw textures for debugging, in a pixel-space orthographic projection...

            glMatrixMode(GL_MODELVIEW);
            glPushMatrix();
            glLoadIdentity();

            glMatrixMode(GL_PROJECTION);
            glPushMatrix();
            glLoadIdentity();
            int w = width;
            FORCE_MAX(w,64);
            int h = height;
            FORCE_MAX(h,64);
            glOrtho(0,w,0,h,-1,1);
            glDisable(GL_DEPTH_TEST);   // restored by glPopAttrib(GL_ENABLE_BIT) below

            glColor4f(1,1,1,1);
            glActiveTexture(GL_TEXTURE0);
            glTexEnvf(GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_REPLACE);
            particleRenderer->renderDebugTextures(0,0,0.10f);

            glMatrixMode(GL_PROJECTION);
            glPopMatrix();
            glMatrixMode(GL_MODELVIEW);
            glPopMatrix();
            glCheckErrors();
        }

        glPopAttrib();
    }

    /// Draws the profiler (toggled with 'o') and a "PAUSED" banner while paused.
    virtual void OnOverlay()
    {
        if (draw_profiler)
            DrawProfile(&NvParticles::Profiler::getSingleton());

        if (paused)
        {
            glColor4f(1,1,1,1);
            Easy::gl::drawSolidRectangle(0,height-16, width, height-32-4);
            glColor4f(1,0,0,1);
            Easy::gl::drawText(width/2,height-32,"PAUSED", GLUT_BITMAP_9_BY_15);
        }
    }

    /// Forces exactly one simulation step (even while paused) and requests a redraw.
    void DoUpdate()
    {
        firstTime = true;
        OnUpdate();
        Repaint();
    }

    virtual void OnIdle()
    {
        OnUpdate();
    }

    /// Emits a box of particles at `origin` with per-axis scale `res`, using the current particle spacing.
    void emitBlock(vec3f origin, vec3f res)
    {
        particleContainer->emitBox(mat44f::scale(res.x, res.y, res.z) * mat44f::translate(origin.x, origin.y, origin.z), particleSpacing, 0.0f, make_vec4f(1.f), 0);
    }

    /// Clears all particles and restarts the frame counter and profiler.
    void Reset()
    {
        currentFrame = 0;

        particleContainer->clear();

        NvParticles::Profiler::getSingleton().Reset();
    }

    /// Key bindings:
    ///   Enter = step, 'o' = profiler, 'd' = dump, '1'-'4' = scene, 'e' = emit block,
    ///   '6'/'7' = points/spheres, 'c'/'b'/'g' = primitives/bounds/grid,
    ///   'h'/'j' = ball speed, '\'' = color style, '='/'-' = radius, '['/']' = color scale.
    /// Unhandled keys (and 'o', after toggling) are passed to GlutApp::OnKey.
    virtual bool OnKey(unsigned char key)
    {
        switch (key)
        {
        case 13:
            // update the simulation on enter.
            DoUpdate();
            return true;

        case 'o':
            draw_profiler = !draw_profiler;
            break;

        case 'd':
            if(particleContainer->getParticleCount() > 128)
                particleContainer->dump(-1, particleContainer->maxParticles/100);
            else
                particleContainer->dump(-1, 1);
            return true;

        case '1':
            newTestMode = 1;
            return true;

        case '2':
            newTestMode = 2;
            return true;

        case '3':
            newTestMode = 3;
            return true;

        case '4':
            newTestMode = 4;
            return true;

        case 'e':
            emitBlock(make_vec3f(10, 80, 0), make_vec3f(16));
            return true;

        case '6':
            renderMethod = "points";
            Repaint();
            return true;

        case '7':
            renderMethod = "spheres";
            Repaint();
            return true;

        case 'c':
            drawPrimitives = !drawPrimitives;
            particleParameters["drawPrimitives"] = drawPrimitives;
            Repaint();
            return true;

        case 'b':
            drawBounds = !drawBounds;
            particleParameters["drawBounds"] = drawBounds;
            Repaint();
            return true;

        case 'g':
            drawGrid = !drawGrid;
            particleParameters["drawGrid"] = drawGrid;
            Repaint();
            return true;

        case 'h':
            ballSpeed -= 0.01;
            STDERR(ballSpeed);
            return true;
        case 'j':
            ballSpeed += 0.01;
            STDERR(ballSpeed);
            return true;

        case '\'':
            colorStyle++;
            STDERR(colorStyle);
            return true;

        case '=':
            renderRadiusFactor += 0.1;
            STDERR(renderRadiusFactor);
            DoUpdate();
            Repaint();
            return true;
        case '-':
            renderRadiusFactor -= 0.1;
            if (renderRadiusFactor < 0.01)
                renderRadiusFactor = 0.01;
            STDERR(renderRadiusFactor);
            DoUpdate();
            Repaint();
            return true;

        case '[':
            colorScale -= 0.1f;
            STDERR(colorScale);
            DoUpdate();
            Repaint();
            return true;
        case ']':
            colorScale += 0.1f;
            STDERR(colorScale);
            DoUpdate();
            Repaint();
            return true;
        }

        return inherited::OnKey(key);
    }

    /// Destroys the solver, container and renderer (pointers are nulled, so a
    /// second call is harmless), resets the CUDA device and optionally waits
    /// for a key (Windows, -wait).
    virtual void OnExit()
    {
        // clean up...
        delete particleSolver;
        particleSolver = 0;
        delete particleContainer;
        particleContainer = 0;
        delete particleRenderer;
        particleRenderer = 0;

        cudaDeviceReset();

#ifdef _WIN32
        if (waitForKey)
        {
            getch();
        }
#endif

        inherited::OnExit();
    }

    /// Parses the command line into the member settings.
    /// @return 1 to continue, 0 after printing -help, or -1 on an unknown
    ///         option, a missing option value, or a non-positive particle count.
    int parseArgs(int argc, char **argv)
    {
        for(int i = 1; i < argc;)
        {
            const char *szBuffer = argv[i++];

            if (!strcasecmp("-help", szBuffer))
            {
                printf("\n");
                printf("NvParticles Test Application - Help\n\n");
                printf("Program parameters:\n");
                printf("\t-wait\t\t\t: Wait for keypress before termination (Windows only).\n");
                printf("\t-iterations #\t\t: number of frames to benchmark (default 0 = no benchmarking).\n");
                printf("\t-count # | -maxParticles #\t: maximum number of particles (default 300000).\n");
                printf("\t-test #\t\t\t: initial condition for test (1-4, default 1).\n");
                printf("\t-device #\t\t: compute device index (default=0).\n");
                return 0;
            }
            else if (!strcasecmp("-wait", szBuffer))
            {
                waitForKey = true;
            }
            else if (!strcasecmp("-count", szBuffer) || !strcasecmp("-maxParticles", szBuffer))
            {
                if(i == argc)
                    return -1;
                maxParticles = atoi(argv[i++]);
                if (maxParticles <= 0)
                    return -1;
            }
            else if(!strcasecmp("-iterations", szBuffer))
            {
                if(i == argc)
                    return -1;
                numBenchmarkIterations = atoi(argv[i++]);
            }
            else if(!strcasecmp("-device", szBuffer))
            {
                if(i == argc)
                    return -1;
                cudaDeviceIndex = atoi(argv[i++]);
            }
            else if(!strcasecmp("-test", szBuffer))
            {
                if(i == argc)
                    return -1;
                newTestMode = atoi(argv[i++]);
            }
            else
                return -1;
        }

        return 1;
    }
};


int main(int argc, char** argv)
{
    CudaParticlesApp app(argc, argv);
    app.Run();
}
