/* ---------------------------------------------------------------------------
 * This software is in the public domain, furnished "as is", without technical
 * support, and with no warranty, express or implied, as to its usefulness for
 * any purpose.

 * Author: Wil Braithwaite.
 *
 */

#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H

#include <cuda_runtime_api.h>
#include <cstdio>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include "math_utils.h"
#include <cutil_math.h>
#include <iostream>
#include "radixsort.cuh"

namespace Easy
{

//---------------------------------------------------------------------------------------
#define NVPARTICLES_CUDA_SAFE_CALL(err) Easy::__mycudaSafeCall(err, __FILE__, __LINE__)

/// Report the result of a CUDA runtime call.
///
/// @param err   status returned by the runtime call being checked.
/// @param file  source file of the call site (supplied by the macro).
/// @param line  source line of the call site (supplied by the macro).
/// @return true when err is cudaSuccess. Otherwise the error is printed to
///         stderr and false is returned; in _DEBUG builds the process aborts
///         instead.
inline bool __mycudaSafeCall( cudaError err, const char *file, const int line )
{
    // Fast path: nothing to report.
    if (err == cudaSuccess)
        return true;

    fprintf(stderr, "cudaSafeCall() Runtime API error (%d) in file <%s>, line %i : %s.\n",
            err, file, line, cudaGetErrorString(err));

#if defined(_DEBUG)
    abort();
#endif
    return false;
}

//---------------------------------------------------------------------------------------
#define NVPARTICLES_CUDA_CHECK_ERROR(msg)           Easy::__mycutilCheckMsg     (msg, __FILE__, __LINE__)

/// Check for a pending CUDA runtime error (for example from a kernel launch)
/// and abort the process if one is found.
///
/// In _DEBUG builds the device is synchronized first, so asynchronous
/// execution errors surface here rather than at some later, unrelated call.
///
/// @param errorMessage  caller-supplied context printed alongside the error.
/// @param file          source file of the check (supplied by the macro).
/// @param line          source line of the check (supplied by the macro).
/// @return true when no error is pending; on error the process aborts.
inline bool __mycutilCheckMsg( const char *errorMessage, const char *file, const int line )
{
    cudaError_t syncErr = cudaSuccess;
#ifdef _DEBUG
    // cudaThreadSynchronize() is deprecated; cudaDeviceSynchronize() is its
    // documented replacement with the same semantics.
    syncErr = cudaDeviceSynchronize();
#endif
    // Always read (and thereby clear) the last error. If it reports success,
    // fall back to the synchronization result so an error returned only by
    // the sync is not silently dropped.
    cudaError_t err = cudaGetLastError();
    if (cudaSuccess == err)
        err = syncErr;

    if( cudaSuccess != err) {
        fprintf(stderr, "cutilCheckMsg cudaThreadSynchronize error: %s in file <%s>, line %i : %s.\n",
                errorMessage, file, line, cudaGetErrorString( err) );
        abort();
        return false; // not reached; kept for compilers that do not know abort() is noreturn
    }
    return true;
}

//---------------------------------------------------------------------------------------
// Synchronize the whole device, then report (and abort on) any pending CUDA
// error via Easy::__mycutilCheckMsg. This is written as a single comma
// expression, not two statements, so it is safe as the body of an unbraced
// if/else and still evaluates to the bool returned by __mycutilCheckMsg.
#define EASY_CUDA_CHECK_ERROR_SYNC(msg)           ((void)cudaDeviceSynchronize(), Easy::__mycutilCheckMsg     (msg, __FILE__, __LINE__))

//---------------------------------------------------------------------------------------
namespace Cu
{

//---------------------------------------------------------------------------------------
/// Thin wrapper around a cudaStream_t.
///
/// - Stream() creates a new stream. This object owns it and destroys it in
///   the destructor.
/// - Stream(cudaStream_t) wraps an existing stream without taking ownership.
///
/// Copies never own the stream: copying an owning Stream yields a non-owning
/// alias, so the handle is destroyed exactly once (by the original owner).
/// An alias must not be used after its owner has been destroyed.
class Stream
{
    cudaStream_t _stream;
    bool _owner;

    // Destroy the stream if this object owns it, then drop ownership.
    void release()
    {
        if (!_owner)
            return;
        _owner = false;
        // The CUDA call must stay outside assert(): with NDEBUG defined the
        // assert expression is not evaluated at all.
        const bool ok = NVPARTICLES_CUDA_SAFE_CALL(cudaStreamDestroy(_stream));
        assert(ok);
        (void)ok;
    }

public:
    /// Wrap an existing stream; it is never destroyed by this object.
    Stream(cudaStream_t s)
        :
        _stream(s), _owner(false)
    {
    }

    /// Create a new stream owned by this object.
    Stream()
        :
        _stream(0), _owner(false)
    {
        // Do not wrap this call in assert(): under NDEBUG the stream would
        // never be created and _stream would be garbage.
        const bool ok = NVPARTICLES_CUDA_SAFE_CALL(cudaStreamCreate(&_stream));
        assert(ok);
        (void)ok;
        // Only take ownership of a stream that was actually created.
        _owner = ok;
    }

    /// Copying produces a non-owning alias (avoids a double destroy).
    Stream(const Stream& other)
        :
        _stream(other._stream), _owner(false)
    {
    }

    /// Assignment releases any stream this object owns (unless it is the
    /// same handle) and becomes a non-owning alias of other's stream.
    Stream& operator=(const Stream& other)
    {
        if (other._stream != _stream)
        {
            release();
            _stream = other._stream;
        }
        return *this;
    }

    ~Stream()
    {
        release();
    }

    operator cudaStream_t() const
    {
        return _stream;
    }

    operator size_t() const
    {
        return (size_t)_stream;
    }

    /// True when cudaStreamQuery reports cudaSuccess. Any other status,
    /// including errors, is reported as "not complete".
    bool isComplete() const
    {
        return cudaStreamQuery(_stream) == cudaSuccess;
    }

    /// Block the host until all work queued on the stream has finished.
    void sync()
    {
        NVPARTICLES_CUDA_SAFE_CALL(cudaStreamSynchronize(_stream));
    }
};

//---------------------------------------------------------------------------------------
/// Scope guard around an existing (non-owning) stream. When the guard goes
/// out of scope, it blocks until all work queued on that stream has
/// finished. With the default argument it waits on stream 0.
class ScopedStream
{
    Stream _guarded;

public:
    ScopedStream(cudaStream_t stream = 0) : _guarded(stream) {}

    ~ScopedStream() { _guarded.sync(); }
};

//---------------------------------------------------------------------------------------
/// Event-based GPU timer that averages the elapsed time of up to numEvents
/// Begin()/End() intervals recorded on one stream. Print() reports the
/// average in seconds, indented by the current nesting depth.
///
/// Intervals are stored in a ring of event pairs: once more than numEvents
/// intervals have been recorded, the oldest slots are overwritten.
///
/// NOTE(review): the class is implicitly copyable; a copy would destroy the
/// same events twice. Do not copy instances.
class ScopedTimer
{
    // Capacity of the event ring buffers below; numEvents is clamped to it.
    enum { kMaxEvents = 100 };

    static int g_indent;
    cudaEvent_t start[kMaxEvents], stop[kMaxEvents];
    char message[256];

    int *indent;
    bool endCalled, beginCalled;
    cudaStream_t stream;
    float totalTime;
    int iterations;
    int numEvents;

public:
    /// @param m               label printed with the timing (truncated to 255 chars; NULL -> "").
    /// @param s               stream on which the start/stop events are recorded.
    /// @param automatic       if true, Begin() is called immediately.
    /// @param _numEvents      number of intervals kept for averaging, clamped to [1, 100].
    /// @param indent_storage  nesting counter to use; NULL selects the shared g_indent.
    inline ScopedTimer(const char *m, cudaStream_t s=0, bool automatic=true, int _numEvents=kMaxEvents, int *indent_storage=NULL)
            :
            indent(indent_storage), endCalled(false), beginCalled(false), stream(s), totalTime(0.f), iterations(0), numEvents(_numEvents)
    {
        if (!indent)
            indent = &g_indent;

        // The event arrays are fixed-size; a larger count would overflow them,
        // and a non-positive one would make the ring index a modulo by zero.
        if (numEvents < 1)
            numEvents = 1;
        if (numEvents > kMaxEvents)
            numEvents = kMaxEvents;

        for(int i=0;i<numEvents;++i)
        {
            NVPARTICLES_CUDA_SAFE_CALL(cudaEventCreate(&start[i]));
            NVPARTICLES_CUDA_SAFE_CALL(cudaEventCreate(&stop[i]));
        }

        // Bounded copy: strcpy would overflow message[] for long labels.
        const char *label = m ? m : "";
        strncpy(message, label, sizeof(message) - 1);
        message[sizeof(message) - 1] = '\0';

        if(automatic)
            Begin();
    }

    /// Record the start event of a new interval. Ignored if an interval is
    /// already open.
    void Begin()
    {
        if(beginCalled)
            return;
        ++(*indent);
        beginCalled = true;
        endCalled = false;
        NVPARTICLES_CUDA_SAFE_CALL(cudaEventRecord( start[(iterations)%numEvents], stream ));
    }

    /// Record the stop event of the open interval. Ignored if no interval is
    /// open.
    void End()
    {
        if(!beginCalled)
            return;
        beginCalled = false;

        NVPARTICLES_CUDA_SAFE_CALL(cudaEventRecord( stop[(iterations)%numEvents], stream ));

        ++iterations;
        --(*indent);
        endCalled = true;
    }

    /// Print the pending measurement, then release all events.
    inline ~ScopedTimer()
    {
        Print();

        for(int i=0;i<numEvents;++i)
        {
            NVPARTICLES_CUDA_SAFE_CALL(cudaEventDestroy(start[i]));
            NVPARTICLES_CUDA_SAFE_CALL(cudaEventDestroy(stop[i]));
        }
    }

    /// Close any open interval, wait for the recorded stop events, print the
    /// average interval time in seconds, and reset the measurement window.
    /// Does nothing if no interval has completed since the last Print().
    void Print()
    {
        if(beginCalled)
            End();

        // Nothing measured since construction or the last Print(); without
        // this guard the average below would be 0/0.
        if(!endCalled || iterations <= 0)
            return;

        const int usedEvents = (iterations < numEvents) ? iterations : numEvents;
        totalTime = 0;
        for(int i=0;i<usedEvents;++i)
        {
            cudaError_t err = cudaEventSynchronize( stop[i] );
            if( cudaSuccess != err )
            {
                fprintf(stdout, "%s: Runtime API error : %s.\n", (const char *)message, cudaGetErrorString( err));
                exit(-1);
            }
            float time = 0.f; // stays 0 if cudaEventElapsedTime fails
            NVPARTICLES_CUDA_SAFE_CALL(cudaEventElapsedTime( &time, start[i], stop[i] ));
            totalTime += time;
        }
        totalTime /= usedEvents;

        for (int i=(*indent);i>0;--i)
            printf("\t");
        // cudaEventElapsedTime reports milliseconds; print seconds.
        printf("%08.6fs (%d): %s\n",totalTime/(1000.f), usedEvents,(const char *)message);

        // Start a fresh measurement window.
        iterations = 0;
        endCalled = false;
    }
};

//---------------------------------------------------------------------------------------
} // end namespace Cu

//---------------------------------------------------------------------------------------
/// Ceiling division: the smallest integer q with q * b >= a.
/// Division by zero is not an error here; a is returned unchanged when b == 0.
inline uint iDivUp(uint a, uint b)
{
    if (!b)
        return a;
    const uint quotient = a / b;
    // Add one more unit whenever the division leaves a remainder.
    return quotient + ((a % b) ? 1u : 0u);
}

//---------------------------------------------------------------------------------------
/// Choose a 1-D launch configuration that covers n elements.
///
/// numThreads becomes the smaller of blockSize and n. numBlocks becomes n
/// divided by numThreads, rounded up (via iDivUp). When n or blockSize is 0,
/// numThreads is 0 and iDivUp's zero-divisor convention makes numBlocks == n.
inline void computeGridSize(uint n, uint blockSize, uint &numBlocks, uint &numThreads)
{
    // A plain comparison avoids naming std::min / min at all, so no
    // compiler-specific branch is needed.
    numThreads = (n < blockSize) ? n : blockSize;
    numBlocks = iDivUp(n, numThreads);
}

//---------------------------------------------------------------------------------------
}

#endif
