/*
 * Copyright 1993-2012 NVIDIA Corporation.  All rights reserved.
 *
 * Please refer to the NVIDIA end user license agreement (EULA) associated
 * with this source code for terms and conditions that govern your use of
 * this software. Any use, reproduction, disclosure, or distribution of
 * this software and related documentation outside the terms of the EULA
 * is strictly prohibited.
 *
 */

#include "NvParticlesCollisionMathInline.h"

// When defined, collisionEvaluate() fetches each primitive's matrix through the
// transform pool (transforms[c.transformIndex].xform) instead of using
// Primitive::xform directly. NOTE(review): the name suggests the pool holds
// slerp-interpolated transforms — confirm with the code that fills it.
#define NVPARTICLES_USE_TRANSFORM_SLERP // read transforms through the transform pool

//------------------------------------------------------------------------------------------

// NVPARTICLES_PRIMITIVE_PARAM(x) reads member x of the primitive parameters:
// the __constant__ device copy when compiling device code (__CUDA_ARCH__ defined),
// and the static host mirror otherwise, so the same inline code runs on both sides.
#ifdef __CUDA_ARCH__
#define NVPARTICLES_PRIMITIVE_PARAM(x) d_primitivesParams.x
#else
#define NVPARTICLES_PRIMITIVE_PARAM(x) h_primitivesParams.x
#endif

// Device-side copy of the primitive parameters; filled by uploadPrimitives() when built with nvcc.
__constant__ PrimitiveGroup d_primitivesParams;
// Host-side mirror; uploadPrimitives() also repoints its transforms[i].xform at its own primitives[i].xform.
static PrimitiveGroup h_primitivesParams;

//------------------------------------------------------------------------------------------
//! Upload the primitive parameters used by the collision routines.
//!
//! Copies @p params into the host mirror (h_primitivesParams). When compiled with
//! nvcc (__CUDACC__), it also enqueues an asynchronous copy of that mirror into the
//! __constant__ copy (d_primitivesParams) on @p stream. The async copy is issued
//! before the pointer fix-up below, so the device copy carries the transform
//! pointers exactly as supplied in @p params. After the copy, every
//! transforms[i].xform in the host mirror is redirected to the mirror's own
//! primitives[i].xform, so host-side evaluation reads each primitive's local transform.
//!
//! NB. This failed on linux if we use extern "C". (Let's work out why sometime!)
//!
//! NOTE(review): the host mirror is modified right after the async copy is issued.
//! This relies on copies from pageable host memory being staged before the call
//! returns. Revisit this if h_primitivesParams is ever made page-locked.
inline
void uploadPrimitives(const PrimitiveGroup& params, cudaStream_t stream=0)
{
    PrimitiveGroup& mirror = h_primitivesParams;
    mirror = params;

#ifdef __CUDACC__
    // Snapshot the mirror to constant memory before patching its pointers.
    NVPARTICLES_CUDA_SAFE_CALL( cudaMemcpyToSymbolAsync(d_primitivesParams, &h_primitivesParams, sizeof(PrimitiveGroup), 0, cudaMemcpyHostToDevice, stream ));
#endif

    // Host-only fix-up: each pool entry points at its primitive's local matrix.
    for (int slot = 0; slot < NVPARTICLES_PRIMITIVE_MAX_COUNT; ++slot)
    {
        mirror.transforms[slot].xform = &mirror.primitives[slot].xform;
    }
}

//------------------------------------------------------------------------------------------

/// Convert a vec3f into a CUDA float3, copying x, y and z component-wise.
NVPARTICLES_CUDA_EXPORT inline
float3 vec3fToFloat3(const vec3f& P)
{
    const float px = P.x;
    const float py = P.y;
    const float pz = P.z;
    return make_float3(px, py, pz);
}

// SIGN(x): 1 when x > 0, otherwise -1 (so SIGN(0) == -1). x is evaluated once.
// The argument is parenthesized so that operator expressions such as SIGN(a & b)
// or SIGN(c ? a : b) compare the whole expression against zero.
#define SIGN(x) (((x) > 0) ? 1 : -1)

//------------------------------------------------------------------------------------------
/// Evaluate a collision between a single point and a single primitive.
///
/// The primitive's matrix comes from the transform pool (transforms[c.transformIndex])
/// when NVPARTICLES_USE_TRANSFORM_SLERP is defined, and from c.xform otherwise. The
/// test itself is delegated to the Collision:: routine that matches c.type and the
/// PRIMITIVE_FLAGS_INTERIOR flag. Interior planes and interior capsules are not
/// supported; those cases never report a collision.
///
/// @param	outContact		[out] nearest point on the primitive's surface; equals P.xyz when there is no collision.
/// @param	outNormal		[out] surface normal at outContact; the zero vector when there is no collision.
/// @param	P				position of the point (only xyz is used).
/// @param	V				velocity of the point (only passed to the plane test).
/// @param	c				primitive to test against.
///
/// @return					penetration depth multiplied by internalScale, or 0 when there is no collision.
///							Raw penetrations below NVPARTICLES_EPSILON (including all negative values) are
///							reported as 0, so a negative result is never produced for a non-negative internalScale.
///
inline static NVPARTICLES_CUDA_EXPORT
float collisionEvaluate(vec3f& outContact, vec3f& outNormal, float4 P, float3 V, const Primitive& c)
{
    vec3f P3 = make_vec3f(P.x,P.y,P.z);

    // Give both outputs a defined value on every path (early return, unsupported
    // interior cases, unknown primitive types), so callers never see garbage.
    outContact = P3;
    outNormal = make_vec3f(0.0f, 0.0f, 0.0f);

    if(c.type == Primitive::PRIMITIVE_NONE)
        return 0;

#if defined(NVPARTICLES_USE_TRANSFORM_SLERP)
    // get the current transformation matrix from the transform-pool...
    PrimitiveTransform transformData = NVPARTICLES_PRIMITIVE_PARAM(transforms[c.transformIndex]);
    mat44f xform = *transformData.xform;
#else
    mat44f xform = c.xform;
#endif

    mat44f xformInv = xform.inverseAffine();

    float penetration = 0;

    if(c.type == Primitive::PRIMITIVE_SPHERE)
    {
        if (c.flags&Primitive::PRIMITIVE_FLAGS_INTERIOR)
        {
            penetration = Collision::UnitSphereInt(P3, NVPARTICLES_PRIMITIVE_PARAM(particleRadius), xform, xformInv, outContact, outNormal);
        }
        else
        {
            penetration = Collision::UnitSphereExt(P3, NVPARTICLES_PRIMITIVE_PARAM(particleRadius), xform, xformInv, outContact, outNormal);
        }
    }
    else if(c.type == Primitive::PRIMITIVE_PLANE)
    {
        if(c.flags&Primitive::PRIMITIVE_FLAGS_INTERIOR)
        {
            // not supported!
        }
        else
        {
            penetration = Collision::PlaneExt(P3, make_vec3f(V), NVPARTICLES_PRIMITIVE_PARAM(particleRadius),
                                                  xform, xformInv,
                                                  c.extents,
                                                  outContact, outNormal);
        }
    }
    else if(c.type == Primitive::PRIMITIVE_CAPSULE)
    {
        if(c.flags&Primitive::PRIMITIVE_FLAGS_INTERIOR)
        {
            // not supported!
        }
        else
        {
            // extents.z is passed as the capsule's extra size parameter.
            penetration = Collision::UnitCapsuleExt(P3, NVPARTICLES_PRIMITIVE_PARAM(particleRadius),
                                                  c.extents.z,
                                                  xform, xformInv,
                                                  outContact, outNormal);
        }
    }
    else if(c.type == Primitive::PRIMITIVE_BOX)
    {
        if(c.flags&Primitive::PRIMITIVE_FLAGS_INTERIOR)
        {
            penetration = Collision::UnitBoxInt(P3, NVPARTICLES_PRIMITIVE_PARAM(particleRadius),
                                                    xform, xformInv,
                                                    outContact, outNormal);
        }
        else
        {
            penetration = Collision::UnitBoxExt(P3, NVPARTICLES_PRIMITIVE_PARAM(particleRadius),
                                                    xform, xformInv,
                                                    outContact, outNormal);
        }
    }

    // Ignore negative or tiny penetrations, and restore the "no collision"
    // outputs in case the routine above wrote partial results.
    if (penetration < NVPARTICLES_EPSILON)
    {
        penetration = 0;
        outContact = P3;
        outNormal = make_vec3f(0.0f, 0.0f, 0.0f);
    }

    return penetration * NVPARTICLES_PRIMITIVE_PARAM(internalScale);
}

//------------------------------------------------------------------------------------------
/// Collision response - projection & reflection impulse method.
/// After Micky Kelager (2006), "Lagrangian Fluid Dynamics Using Smoothed Particle Hydrodynamics".
///
/// 1. Moves outP onto the contact point, keeping outP.w.
/// 2. If the speed |outV| is non-zero, removes the normal velocity with a
///    penetration-weighted restitution:
///        V' = V - (1 + restitution * penetration / (deltaTime * |V|)) * (V.N) N
///
/// Always returns a zero force, because the response is applied directly to outP/outV.
/// NOTE(review): the reflection is applied whatever the sign of V.N, including for a
/// point already moving out along N — confirm this is intended.
///
NVPARTICLES_CUDA_EXPORT inline static
float3 collisionResponse_ProjectReflect(float4& outP, float3& outV, float penetration, vec3f contact, vec3f normal)
{
    const float w = outP.w;
    outP = make_float4(vec3fToFloat3(contact), w);

    const float speed = length(outV);
    if (speed != 0)
    {
        const float3 n = vec3fToFloat3(normal);
        const float3 normalVelocity = dot(n, outV) * n;
        const float k = 1 + NVPARTICLES_PRIMITIVE_PARAM(restitution) * ((penetration) / (NVPARTICLES_PRIMITIVE_PARAM(deltaTime) * speed));
        outV += -k * normalVelocity;
    }

    return make_float3(0);
}

//------------------------------------------------------------------------------------------
/// Collision response - penalty-force method.
/// From "Real-time particle-based fluid simulation with rigid body interaction" (Amada T.).
///
/// Returns the spring-damper force along the contact normal:
///     F = (stiffness * penetration - damping * (V.N)) * N
/// outP and outV are not modified. Only stiffness and damping are used; restitution
/// and friction are not applied here.
///
NVPARTICLES_CUDA_EXPORT inline static
float3 collisionResponse_PenaltyForce(float4& outP, float3& outV, float penetration, vec3f contact, vec3f normal)
{
    const float3 n = vec3fToFloat3(normal);
    const float normalSpeed = dot(outV, n);
    const float magnitude = NVPARTICLES_PRIMITIVE_PARAM(stiffness) * penetration - NVPARTICLES_PRIMITIVE_PARAM(damping) * normalSpeed;

    // Disabled alternative that also applies restitution and friction:
    //   Vn = dot(V,N)*N;  Vt = V - Vn;
    //   force = stiffness*penetration*N - restitution*Vn - friction*Vt;
    // Open question from the original author: should restitution use (1+restitution)*Vn
    // to produce a reflection, or is that already handled by the penalty term?
    return magnitude * n;
}

//------------------------------------------------------------------------------------------
/// Default primitive iterator: resolves one point against each visited primitive.
///
/// pre() seeds outP/outV from P/V and clears outForce and hits. item() tests the
/// current primitive against the already-corrected outP/outV, so the responses to
/// earlier primitives feed into later tests. METHOD picks the response applied on
/// a hit:
///   0 - collisionResponse_ProjectReflect (modifies outP/outV; adds zero force)
///   1 - collisionResponse_PenaltyForce   (accumulates into outForce)
/// Any other METHOD evaluates collisions but applies no response and counts no hits.
///
template<int METHOD>
struct DefaultCollisionIterator : SimpleIterator
{
    float4 P;
    float3 V;
    float4 outP;
    float3 outV;
    float3 outForce;
    int hits;

    inline NVPARTICLES_CUDA_EXPORT static
    void pre(DefaultCollisionIterator& it, const uint &i)
    {
        it.hits = 0;
        it.outP = it.P;
        it.outV = it.V;
        it.outForce = make_float3(0);
    }

    inline NVPARTICLES_CUDA_EXPORT static
    bool item(DefaultCollisionIterator& it, const uint &i)
    {
        vec3f hitPoint, hitNormal;
        const float depth = collisionEvaluate(hitPoint, hitNormal, it.outP, it.outV,
                                              NVPARTICLES_PRIMITIVE_PARAM(primitives[it.__iteratorIndex]));

        // No penetration (written this way so a NaN depth is also skipped).
        if (!(depth > 0))
            return true;

        if (METHOD == 0)
            it.outForce += collisionResponse_ProjectReflect(it.outP, it.outV, depth, hitPoint, hitNormal);
        else if (METHOD == 1)
            it.outForce += collisionResponse_PenaltyForce(it.outP, it.outV, depth, hitPoint, hitNormal);
        else
            return true; // unknown METHOD: no response, not counted

        ++it.hits;
        return true;
    }

    inline NVPARTICLES_CUDA_EXPORT static
    void post(DefaultCollisionIterator& it, const uint &i)
    {
        // nothing to finalize
    }
};

//------------------------------------------------------------------------------------------
/// Iterator that tests whether point P penetrates any visited primitive.
///
/// pre() clears hits. item() evaluates the current primitive with a zero
/// velocity. On the first penetration it increments hits and returns false.
/// NOTE(review): returning false presumably stops simpleIterate early — confirm.
///
struct CollisionTestIterator : SimpleIterator
{
    float4 P;
    int hits;

    inline NVPARTICLES_CUDA_EXPORT static
    void pre(CollisionTestIterator& it, const uint &i)
    {
        it.hits = 0;
    }

    inline NVPARTICLES_CUDA_EXPORT static
    bool item(CollisionTestIterator& it, const uint &i)
    {
        vec3f unusedContact, unusedNormal;
        const float depth = collisionEvaluate(unusedContact, unusedNormal, it.P, make_float3(0),
                                              NVPARTICLES_PRIMITIVE_PARAM(primitives[it.__iteratorIndex]));

        const bool inside = (depth > 0);
        if (inside)
            ++it.hits;

        return !inside;
    }
};

//------------------------------------------------------------------------------------------
/// Run @p data over the first numPrimitives entries of the primitive list.
///
/// The trip count is the runtime numPrimitives rather than the compile-time maximum.
/// Open performance question: a fully unrolled loop over
/// NVPARTICLES_PRIMITIVE_MAX_COUNT, i.e. simpleIterate<Iterator, NVPARTICLES_PRIMITIVE_MAX_COUNT>(data, i),
/// may or may not be faster than trimming the count — profile before switching.
template<class Iterator>
inline NVPARTICLES_CUDA_EXPORT
void iteratePrimitives(Iterator& data, const uint& i)
{
    simpleIterate<Iterator>(data, i, NVPARTICLES_PRIMITIVE_PARAM(numPrimitives));
}

//------------------------------------------------------------------------------------------
