// (c) 2008 Steven Gratton
// Guided by examples from the AMD Stream SDK


#include <iostream>
#include <iomanip>
#include <string>

#include <ctime>

#include "cal.h"
#include "calcl.h"

#include "cholmultiprog.h"

// Three-line IL program: a version header, ret_dyn and end.  Presumably kept
// as a stand-in that can be assigned to a kernel source string in main() when
// only the compile/link/load path is being checked.
std::string ILcheck("il_ps_2_0\nret_dyn\nend\n");

using namespace std;

int main(int argc, char** argv)
{
  //make sure these are the same and a multiple of four!
  // or 256 until padding is fully understood!  
  CALint width=4096;
  CALint height=width;

  // these should be multiples of four for proper display

  CALint viewstartx=0;
  CALint viewstarty=0;  

  CALint viewwidth= 64;
  CALint viewheight=64;

  std::string kernel1 = choltopleft;
  //std::string kernel1 = ILcheck;
  std::string kernel2 = cholstrip;
  //std::string kernel2 = ILcheck;
  std::string kernel3 = choldiag;
  //std::string kernel3 = ILcheck;
  std::string kernel4 = cholhiup;
  //std::string kernel4 = ILcheck;




  calInit();
  CALuint numDevices = 0;
  calDeviceGetCount(&numDevices);

  cout << "Num devices =" << numDevices << endl;

  CALdevice device = 0;
  calDeviceOpen(&device, 0);

  CALdeviceinfo info;
  calDeviceGetInfo(&info, 0);

  CALcontext ctx = 0;
  calCtxCreate(&ctx, device);

  CALobject obj1 = NULL;
  CALimage image1 = NULL;
  CALlanguage lang1 = CAL_LANGUAGE_IL;



  if (calclCompile(&obj1, lang1, kernel1.c_str(), info.target) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel1 compilation failed. Exiting.\n");
      return 1;
    }
  else
    {
      cout << "kernel1 compiled fine" << endl;
    };
  if (calclLink(&image1, &obj1, 1) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel1 linking failed. Exiting.\n");
      return 1;
    }

  CALobject obj2 = NULL;
  CALimage image2 = NULL;
  CALlanguage lang2 = CAL_LANGUAGE_IL;


  if (calclCompile(&obj2, lang2, kernel2.c_str(), info.target) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel2 compilation failed. Exiting.\n");
      return 1;
    }
  else
    {
      cout << "kernel2 compiled fine" << endl;
    };
  if (calclLink(&image2, &obj2, 1) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel2 linking failed. Exiting.\n");
      return 1;
    }


  CALobject obj3 = NULL;
  CALimage image3 = NULL;
  CALlanguage lang3 = CAL_LANGUAGE_IL;





  if (calclCompile(&obj3, lang3, kernel3.c_str(), info.target) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel3 compilation failed. Exiting.\n");
      return 1;
    }
  else
    {
      cout << "kernel3 compiled fine" << endl;
    };
  if (calclLink(&image3, &obj3, 1) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel3 linking failed. Exiting.\n");
      return 1;
    }


  CALobject obj4 = NULL;
  CALimage image4 = NULL;
  CALlanguage lang4 = CAL_LANGUAGE_IL;



  if (calclCompile(&obj4, lang4, kernel4.c_str(), info.target) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel4 compilation failed. Exiting.\n");
      return 1;
    }
  else
    {
      cout << "kernel4 compiled fine" << endl;
    };

  if (calclLink(&image4, &obj4, 1) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel4 linking failed. Exiting.\n");
      return 1;
    }


  CALobject obj5 = NULL;
  CALimage image5 = NULL;
  CALlanguage lang5 = CAL_LANGUAGE_IL;

  std::string kernel5 = maketestmat;
  //std::string kernel5 = ILcheck;

  if (calclCompile(&obj5, lang5, kernel5.c_str(), info.target) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel5 compilation failed. Exiting.\n");
      return 1;
    }
  else
    {
      cout << "kernel5 compiled fine" << endl;
    };

  if (calclLink(&image5, &obj5, 1) != CAL_RESULT_OK)
    {
      fprintf(stdout, "Kernel5 linking failed. Exiting.\n");
      return 1;
    }


  cout << "after compiles..." << endl;


  CALresource globRes=0;
  if(calResAllocLocal2D(&globRes, device, width/4,height, 
			CAL_FORMAT_FLOAT_4, CAL_RESALLOC_GLOBAL_BUFFER)
     !=CAL_RESULT_OK) 
    {
      printf("Global resource allocation failed.\n");
    }
  else
    {
      cout << "global buffer fine." << endl;
    }

  // Constant resource

  //now trying a remote resource...
  CALresource constRes = 0;
  if(calResAllocLocal1D(&constRes, device, 1, CAL_FORMAT_INT_4, 0)
     //if(calResAllocRemote1D(&constRes, &device,1, 1, CAL_FORMAT_INT_4, 0)
     !=CAL_RESULT_OK)
    {
      cout << "Constant resource allocation failed." << endl;
    }
  else
    {
      cout << "Const buffer fine." << endl;
    }


  float* gdata=NULL;
  CALuint gpitch=0;
  CALmem globMem=0;
  calResMap((CALvoid**)&gdata,&gpitch,globRes,0);
  cout << "gpitch="<< gpitch << endl;
  for (int i = 0; i < height; ++i)
    {
      float* tmp = &gdata[i * gpitch*4];  // times 4 since float4's
      for (int j = 0; j < width; ++j)
	{
	  //  tmp[j] = (float) (height*i+j+1);
	  tmp[j] = 2.e-1f;
	}
    }

  /* for testing doubles...
     int* gtmp=(int*) gdata;
     gtmp[0]=0x00000000;gtmp[1]=0x401e0000;
     gtmp[2]=0x00000000;gtmp[3]=0x00000000;
  */
  calResUnmap(globRes);
  cout << "here" << endl;

  int* constPtr = NULL;
  CALuint constPitch = 0;
  CALmem constMem = 0;


  calCtxGetMem(&constMem, ctx, constRes);
  calCtxGetMem(&globMem, ctx, globRes);

  CALmodule module1 = 0;
  calModuleLoad(&module1, ctx, image1);
  CALfunc func1 = 0;

  CALname constName1=0;
  CALname globName1=0;

  calModuleGetEntry(&func1, ctx, module1, "main");
  calModuleGetName(&constName1, ctx, module1, "cb0");
  calModuleGetName(&globName1,ctx,module1,"g[]");

  calCtxSetMem(ctx, constName1, constMem);
  calCtxSetMem(ctx,globName1,globMem);


  CALmodule module2 = 0;
  calModuleLoad(&module2, ctx, image2);
  CALfunc func2 = 0;

  CALname constName2=0;
  CALname globName2=0;

  calModuleGetEntry(&func2, ctx, module2, "main");
  calModuleGetName(&constName2, ctx, module2, "cb0");
  calModuleGetName(&globName2,ctx,module2,"g[]");

  calCtxSetMem(ctx, constName2, constMem);
  calCtxSetMem(ctx,globName2,globMem);

  CALmodule module3 = 0;
  calModuleLoad(&module3, ctx, image3);
  CALfunc func3 = 0;

  CALname constName3=0;
  CALname globName3=0;

  calModuleGetEntry(&func3, ctx, module3, "main");
  calModuleGetName(&constName3, ctx, module3, "cb0");
  calModuleGetName(&globName3,ctx,module3,"g[]");

  calCtxSetMem(ctx, constName3, constMem);
  calCtxSetMem(ctx,globName3,globMem);

  CALmodule module4 = 0;
  calModuleLoad(&module4, ctx, image4);
  CALfunc func4 = 0;

  CALname constName4=0;
  CALname globName4=0;

  calModuleGetEntry(&func4, ctx, module4, "main");
  calModuleGetName(&constName4, ctx, module4, "cb0");
  calModuleGetName(&globName4,ctx,module4,"g[]");

  calCtxSetMem(ctx, constName4, constMem);
  calCtxSetMem(ctx,globName4,globMem);


  CALmodule module5 = 0;
  calModuleLoad(&module5, ctx, image5);
  CALfunc func5 = 0;

  CALname constName5=0;
  CALname globName5=0;

  calModuleGetEntry(&func5, ctx, module5, "main");
  calModuleGetName(&constName5, ctx, module5, "cb0");
  calModuleGetName(&globName5,ctx,module5,"g[]");

  calCtxSetMem(ctx, constName5, constMem);
  calCtxSetMem(ctx,globName5,globMem);



  CALevent e = 0;

  CALdomain domain5 = {0, 0, 2,(width/4+1)/2};
  calResMap((CALvoid**)&constPtr, &constPitch, constRes, 0);
  constPtr[0] = gpitch,     constPtr[1] = height/4;
  constPtr[2] = 0;     constPtr[3] = width/4;
  calResUnmap(constRes);

  cout << "running kernel5 ...";
  calCtxRunProgram(&e, ctx, func5, &domain5);
  cout << calGetErrorString();
  while (calCtxIsEventDone(ctx, e) == CAL_RESULT_PENDING);
  cout << ": done kernel5" << endl;

  volatile clock_t gputime;
  gputime=clock();

  e=0;


  int pos=0;

  cout << "just before while loop" << endl;

  while (pos<(width/4-2))
    {
      e=0;
      CALdomain domain1 = {0, 0, 1,1};
      CALdomain domain2 = {0, 0, 2,(width/4-pos+1)/2};
      CALdomain domain3 = {0, 0, 2,(width/4-pos+1)/2};
      CALdomain domain4 = {0, 0, 2,(width/4-pos+1)/2};


      calResMap((CALvoid**)&constPtr, &constPitch, constRes, 0);
      constPtr[0] = gpitch,     constPtr[1] = height/4;
      constPtr[2] = pos;     constPtr[3] = width/4;
      calResUnmap(constRes);

      //cout <<"For pos=" << pos << ", domainheight="<< (width/4-pos+1)/2 <<", and ";

      calCtxRunProgram(&e, ctx, func1, &domain1);
      //while (calCtxIsEventDone(ctx, e) == CAL_RESULT_PENDING);e=0;

      calCtxRunProgram(&e, ctx, func2, &domain2);
      // The following "while" seems to be crucial for result correctness
      // Is the DPP calculating multiple kernels simultaneously,
      // or starting the next one before finishing the global writes to
      // the previous one?
      while (calCtxIsEventDone(ctx, e) == CAL_RESULT_PENDING);e=0;

      calCtxRunProgram(&e, ctx, func3, &domain3);
      while (calCtxIsEventDone(ctx, e) == CAL_RESULT_PENDING);e=0;

      calCtxRunProgram(&e, ctx, func4, &domain4);
      while (calCtxIsEventDone(ctx, e) == CAL_RESULT_PENDING);

      // cout << calGetErrorString()<<endl;
      pos++;
    }

  if(width/4>1)
    {
      e=0;
      CALdomain domain1 = {0, 0, 1,1};
      CALdomain domain2 = {0, 0, 1,1};
      CALdomain domain3 = {0, 0, 1,1};


      calResMap((CALvoid**)&constPtr, &constPitch, constRes, 0);
      constPtr[0] = gpitch,     constPtr[1] = height/4;
      constPtr[2] = pos;     constPtr[3] = width/4;
      calResUnmap(constRes);

      cout <<"For pos =" << pos << ":"<< endl;

      cout << "running kernel1 ...";
      calCtxRunProgram(&e, ctx, func1, &domain1);
      cout << calGetErrorString();
      cout << ": done kernel1" << endl;

      cout << "running kernel2 ...";
      calCtxRunProgram(&e, ctx, func2, &domain2);
      cout << calGetErrorString();
      cout << ": done kernel2" << endl;

      cout << "running kernel3 ...";
      calCtxRunProgram(&e, ctx, func3, &domain3);
      cout << calGetErrorString();
      cout << ": done kernel3" << endl;


      while (calCtxIsEventDone(ctx, e) == CAL_RESULT_PENDING);
      pos++;
    }

  e=0;
  CALdomain domain1 = {0, 0, 1,1};
  calResMap((CALvoid**)&constPtr, &constPitch, constRes, 0);
  constPtr[0] = gpitch,     constPtr[1] = height/4;
  constPtr[2] = pos;     constPtr[3] = width/4;
  calResUnmap(constRes);

  cout <<"For pos =" << pos << ":"<< endl;

  cout << "running kernel1 ...";
  calCtxRunProgram(&e, ctx, func1, &domain1);
  cout << calGetErrorString();
  cout << ": done kernel1" << endl;
  while (calCtxIsEventDone(ctx, e) == CAL_RESULT_PENDING);


  gputime=clock()-gputime;

  cout << "gpu time=" << gputime/1.e6f <<" s." <<endl;



  calResMap((CALvoid**)&gdata, &gpitch, globRes, 0);
  /* simple plot... 
     for (int i = 0; i < viewheight; ++i)
     {
     float* tmp = &gdata[i * gpitch*4];
     for(int j = 0; j < viewwidth; ++j)
     {
     cout << setw(14) << tmp[j] << " ";
     }
     cout << endl;
     }
  */

  //clever plot...
  cout.precision(10);
  for (int i = viewstarty; i < (viewstarty+viewheight); i+=4)
    {
      float* tmp1 = &gdata[i * gpitch*4];
      float* tmp2 = &gdata[(i+1) * gpitch*4];
      float* tmp3 = &gdata[(i+2) * gpitch*4];
      float* tmp4 = &gdata[(i+3) * gpitch*4];
      int j;
      for(j = 4*viewstartx; j < 4*(viewstartx+viewwidth); j+=4)
	{
	  if (j<4*gpitch) {cout  << setw(14) << tmp1[j] << " ";} else {
	    if (j<8*gpitch) {cout << setw(14) << tmp2[j-4*gpitch] << " ";} else{
	      if (j<12*gpitch) {cout << setw(14) << tmp3[j-8*gpitch] << " ";} else{
		if (j<16*gpitch) {cout << setw(14) << tmp4[j-12*gpitch] << " ";}}}};
	  if (j%16==12) cout << "   ";
	}
      cout << endl;
      for(j = 4*viewstartx; j < 4*(viewstartx+viewwidth); j+=4)
	{
	  if (j<4*gpitch) {cout  << setw(14) << tmp1[j+1] << " ";} else {
	    if (j<8*gpitch) {cout << setw(14) << tmp2[j-4*gpitch+1] << " ";} else{
	      if (j<12*gpitch) {cout << setw(14) << tmp3[j-8*gpitch+1] << " ";} else{
		if (j<16*gpitch) {cout << setw(14) << tmp4[j-12*gpitch+1] << " ";}}}};
	  if (j%16==12) cout << "   ";
	}
      cout << endl;
      for(j = 4*viewstartx; j < 4*(viewstartx+viewwidth); j+=4)
	{
	  if (j<4*gpitch) {cout << setw(14) << tmp1[j+2] << " ";} else {
	    if (j<8*gpitch) {cout << setw(14) << tmp2[j-4*gpitch+2] << " ";} else{
	      if (j<12*gpitch) {cout << setw(14) << tmp3[j-8*gpitch+2] << " ";} else{
		if (j<16*gpitch) {cout << setw(14) << tmp4[j-12*gpitch+2] << " ";}}}};
	  if (j%16==12) cout << "   ";
	}
      cout << endl; 
      for(j = 4*viewstartx; j < 4*(viewstartx+viewwidth); j+=4)
	{
	  if (j<4*gpitch) {cout << setw(14) << tmp1[j+3] << " ";} else {
	    if (j<8*gpitch) {cout << setw(14) << tmp2[j-4*gpitch+3] << " ";} else{
	      if (j<12*gpitch) {cout << setw(14) << tmp3[j-8*gpitch+3] << " ";} else{
		if (j<16*gpitch) {cout << setw(14) << tmp4[j-12*gpitch+3] << " ";}}}};
	  if (j%16==12) cout << "   ";}
      cout << endl;
      cout << endl;
    }  
  // for testing doubles...
  //int* tmpptr= (int*) &gdata[0];
  //cout << hex << tmpptr[0] << " " << tmpptr[1] << " " << tmpptr[2] << " " << tmpptr[3] << endl;


  calResUnmap(globRes);

  calModuleUnload(ctx, module1);
  calModuleUnload(ctx, module2);

  // Freeing compiled kernel binary
  calclFreeImage(image1);
  calclFreeObject(obj1);
  calclFreeImage(image2);
  calclFreeObject(obj2);

  // Releasing resource from context
  calCtxReleaseMem(ctx, constMem);
  calCtxReleaseMem(ctx, globMem);

  // Deallocating resources
  calResFree(constRes);
  calResFree(globRes);

  // Destroying context
  calCtxDestroy(ctx);

  // Closing device
  calDeviceClose(device);

  // Shutting down CAL
  calShutdown();

  return 0;
}