Commit 0bb97466 authored by Stefan's avatar Stefan
Browse files

update linAlg

parent b6847dbd
......@@ -47,8 +47,7 @@ void linAlg_t::setup() {
//add defines
kernelInfo["defines/" "p_blockSize"] = blocksize;
reallocBuffers(BLOCKSIZE * sizeof(dfloat));
reallocBuffers(blocksize * sizeof(dfloat));
string oklDir;
oklDir.assign(getenv("NEKRS_INSTALL_DIR"));
......@@ -81,21 +80,46 @@ void linAlg_t::setup() {
"linAlgAXPBY.okl",
"axpbyz",
kernelInfo);
if (axpbyzManyKernel.isInitialized()==false)
axpbyzManyKernel = device.buildKernel(oklDir +
"linAlgAXPBY.okl",
"axpbyzMany",
kernelInfo);
if (axmyKernel.isInitialized()==false)
axmyKernel = device.buildKernel(oklDir +
"linAlgAXMY.okl",
"axmy",
kernelInfo);
if (axmyManyKernel.isInitialized()==false)
axmyManyKernel = device.buildKernel(oklDir +
"linAlgAXMY.okl",
"axmyMany",
kernelInfo);
if (axmyzKernel.isInitialized()==false)
axmyzKernel = device.buildKernel(oklDir +
"linAlgAXMY.okl",
"axmyz",
kernelInfo);
if (adyKernel.isInitialized()==false)
adyKernel = device.buildKernel(oklDir +
"linAlgAXDY.okl",
"ady",
kernelInfo);
if (axdyKernel.isInitialized()==false)
axdyKernel = device.buildKernel(oklDir +
"linAlgAXDY.okl",
"axdy",
kernelInfo);
if (aydxKernel.isInitialized()==false)
aydxKernel = device.buildKernel(oklDir +
"linAlgAXDY.okl",
"aydx",
kernelInfo);
if (aydxManyKernel.isInitialized()==false)
aydxManyKernel = device.buildKernel(oklDir +
"linAlgAXDY.okl",
"aydxMany",
kernelInfo);
if (axmyzKernel.isInitialized()==false)
axmyzKernel = device.buildKernel(oklDir +
"linAlgAXDY.okl",
......@@ -147,9 +171,14 @@ linAlg_t::~linAlg_t() {
scaleKernel.free();
axpbyKernel.free();
axpbyzKernel.free();
axpbyzManyKernel.free();
axmyKernel.free();
axmyManyKernel.free();
axmyzKernel.free();
axdyKernel.free();
aydxKernel.free();
aydxManyKernel.free();
adyKernel.free();
axdyzKernel.free();
sumKernel.free();
minKernel.free();
......@@ -190,12 +219,21 @@ void linAlg_t::axpbyz(const dlong N, const dfloat alpha, occa::memory& o_x,
const dfloat beta, occa::memory& o_y, occa::memory& o_z) {
axpbyzKernel(N, alpha, o_x, beta, o_y, o_z);
}
void linAlg_t::axpbyzMany(const dlong N, const dlong Nfields, const dlong fieldOffset, const dfloat alpha, occa::memory& o_x,
const dfloat beta, occa::memory& o_y, occa::memory& o_z) {
axpbyzManyKernel(N, Nfields, fieldOffset, alpha, o_x, beta, o_y, o_z);
}
// o_y[n] = alpha*o_x[n]*o_y[n]
void linAlg_t::axmy(const dlong N, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y) {
axmyKernel(N, alpha, o_x, o_y);
}
void linAlg_t::axmyMany(const dlong N, const dlong Nfields, const dlong offset,
const dlong mode, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y) {
axmyManyKernel(N, Nfields, offset, mode, alpha, o_x, o_y);
}
// o_z[n] = alpha*o_x[n]*o_y[n]
void linAlg_t::axmyz(const dlong N, const dfloat alpha,
......@@ -208,6 +246,20 @@ void linAlg_t::axdy(const dlong N, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y) {
axdyKernel(N, alpha, o_x, o_y);
}
void linAlg_t::aydx(const dlong N, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y) {
aydxKernel(N, alpha, o_x, o_y);
}
void linAlg_t::aydxMany(const dlong N, const dlong Nfields, const dlong fieldOffset,
const dlong mode, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y) {
aydxManyKernel(N, Nfields, fieldOffset, mode, alpha, o_x, o_y);
}
// o_y[n] = alpha/o_y[n]
void linAlg_t::ady(const dlong N, const dfloat alpha,
occa::memory& o_y) {
adyKernel(N, alpha, o_y);
}
// o_z[n] = alpha*o_x[n]/o_y[n]
void linAlg_t::axdyz(const dlong N, const dfloat alpha,
......@@ -217,13 +269,13 @@ void linAlg_t::axdyz(const dlong N, const dfloat alpha,
// \sum o_a
dfloat linAlg_t::sum(const dlong N, occa::memory& o_a, MPI_Comm _comm) {
int Nblock = (N+BLOCKSIZE-1)/BLOCKSIZE;
int Nblock = (N+blocksize-1)/blocksize;
const dlong Nbytes = Nblock * sizeof(dfloat);
if(o_scratch.size() < Nbytes) reallocBuffers(Nbytes);
sumKernel(Nblock, N, o_a, o_scratch);
o_scratch.copyTo(scratch, Nblock*sizeof(dfloat));
o_scratch.copyTo(scratch, Nbytes);
dfloat sum = 0;
for(dlong n=0;n<Nblock;++n){
......@@ -238,16 +290,16 @@ dfloat linAlg_t::sum(const dlong N, occa::memory& o_a, MPI_Comm _comm) {
// \min o_a
dfloat linAlg_t::min(const dlong N, occa::memory& o_a, MPI_Comm _comm) {
int Nblock = (N+BLOCKSIZE-1)/BLOCKSIZE;
int Nblock = (N+blocksize-1)/blocksize;
const dlong Nbytes = Nblock * sizeof(dfloat);
if(o_scratch.size() < Nbytes) reallocBuffers(Nbytes);
minKernel(Nblock, N, o_a, o_scratch);
o_scratch.copyTo(scratch, Nblock*sizeof(dfloat));
o_scratch.copyTo(scratch, Nbytes);
dfloat min = 9e30;
for(dlong n=0;n<Nblock;++n){
dfloat min = scratch[0];
for(dlong n=1;n<Nblock;++n){
min = (scratch[n] < min) ? scratch[n]:min;
}
......@@ -258,16 +310,16 @@ dfloat linAlg_t::min(const dlong N, occa::memory& o_a, MPI_Comm _comm) {
// \max o_a
dfloat linAlg_t::max(const dlong N, occa::memory& o_a, MPI_Comm _comm) {
int Nblock = (N+BLOCKSIZE-1)/BLOCKSIZE;
int Nblock = (N+blocksize-1)/blocksize;
const dlong Nbytes = Nblock * sizeof(dfloat);
if(o_scratch.size() < Nbytes) reallocBuffers(Nbytes);
maxKernel(Nblock, N, o_a, o_scratch);
o_scratch.copyTo(scratch, Nblock*sizeof(dfloat));
o_scratch.copyTo(scratch, Nbytes);
dfloat max = -9e30;
for(dlong n=0;n<Nblock;++n){
dfloat max = scratch[0];
for(dlong n=1;n<Nblock;++n){
max = (scratch[n] > max) ? scratch[n]:max;
}
......@@ -280,8 +332,8 @@ dfloat linAlg_t::max(const dlong N, occa::memory& o_a, MPI_Comm _comm) {
// ||o_a||_2
/*
dfloat linAlg_t::norm2(const dlong N, occa::memory& o_a, MPI_Comm _comm) {
int Nblock = (N+BLOCKSIZE-1)/BLOCKSIZE;
Nblock = (Nblock>BLOCKSIZE) ? BLOCKSIZE : Nblock; //limit to BLOCKSIZE entries
int Nblock = (N+blocksize-1)/blocksize;
Nblock = (Nblock>blocksize) ? blocksize : Nblock; //limit to blocksize entries
norm2Kernel(Nblock, N, o_a, o_scratch);
......@@ -302,13 +354,13 @@ dfloat linAlg_t::norm2(const dlong N, occa::memory& o_a, MPI_Comm _comm) {
// o_x.o_y
dfloat linAlg_t::innerProd(const dlong N, occa::memory& o_x, occa::memory& o_y,
MPI_Comm _comm) {
int Nblock = (N+BLOCKSIZE-1)/BLOCKSIZE;
int Nblock = (N+blocksize-1)/blocksize;
const dlong Nbytes = Nblock * sizeof(dfloat);
if(o_scratch.size() < Nbytes) reallocBuffers(Nbytes);
innerProdKernel(Nblock, N, o_x, o_y, o_scratch);
o_scratch.copyTo(scratch, Nblock*sizeof(dfloat));
o_scratch.copyTo(scratch, Nbytes);
dfloat dot = 0;
for(dlong n=0;n<Nblock;++n){
......@@ -325,13 +377,13 @@ dfloat linAlg_t::innerProd(const dlong N, occa::memory& o_x, occa::memory& o_y,
dfloat linAlg_t::weightedInnerProd(const dlong N, occa::memory& o_w,
occa::memory& o_x, occa::memory& o_y,
MPI_Comm _comm) {
int Nblock = (N+BLOCKSIZE-1)/BLOCKSIZE;
int Nblock = (N+blocksize-1)/blocksize;
const dlong Nbytes = Nblock * sizeof(dfloat);
if(o_scratch.size() < Nbytes) reallocBuffers(Nbytes);
weightedInnerProdKernel(Nblock, N, o_w, o_x, o_y, o_scratch);
o_scratch.copyTo(scratch, Nblock*sizeof(dfloat));
o_scratch.copyTo(scratch, Nbytes);
dfloat dot = 0;
for(dlong n=0;n<Nblock;++n){
......@@ -347,13 +399,13 @@ dfloat linAlg_t::weightedInnerProd(const dlong N, occa::memory& o_w,
// ||o_a||_w2
dfloat linAlg_t::weightedNorm2(const dlong N, occa::memory& o_w,
occa::memory& o_a, MPI_Comm _comm) {
int Nblock = (N+BLOCKSIZE-1)/BLOCKSIZE;
int Nblock = (N+blocksize-1)/blocksize;
const dlong Nbytes = Nblock * sizeof(dfloat);
if(o_scratch.size() < Nbytes) reallocBuffers(Nbytes);
weightedNorm2Kernel(Nblock, N, o_w, o_a, o_scratch);
o_scratch.copyTo(scratch, Nblock*sizeof(dfloat));
o_scratch.copyTo(scratch, Nbytes);
dfloat norm = 0;
for(dlong n=0;n<Nblock;++n){
......
......@@ -46,10 +46,10 @@ private:
void setup();
void reallocBuffers(const dlong Nbytes);
public:
linAlg_t(occa::device& _device, occa::properties*& _kernelInfo, MPI_Comm& _comm) {
linAlg_t(occa::device& _device, occa::properties& _kernelInfo, MPI_Comm& _comm) {
blocksize = BLOCKSIZE;
device = _device;
kernelInfo = *(_kernelInfo);
kernelInfo = _kernelInfo;
comm = _comm;
setup();
}
......@@ -77,16 +77,31 @@ public:
void axpbyz(const dlong N, const dfloat alpha, occa::memory& o_x,
const dfloat beta, occa::memory& o_y,
occa::memory& o_z);
void axpbyzMany(const dlong N, const dlong Nfields, const dlong offset, const dfloat alpha, occa::memory& o_x,
const dfloat beta, occa::memory& o_y,
occa::memory& o_z);
// o_y[n] = alpha*o_x[n]*o_y[n]
void axmy(const dlong N, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y);
// mode 1:
// o_y[n,fld] = alpha*o_x[n,fld]*o_y[n,fld]
// mode 0:
// o_y[n,fld] = alpha*o_x[n]*o_y[n,fld]
void axmyMany(const dlong N,
const dlong Nfields,
const dlong offset, const dlong mode,
const dfloat alpha,
occa::memory& o_x, occa::memory& o_y);
// o_z[n] = alpha*o_x[n]*o_y[n] (new)
void axmyz(const dlong N, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y,
occa::memory& o_z);
// o_y[n] = alpha/o_y[n]
void ady(const dlong N, const dfloat alpha,
occa::memory& o_y);
// o_y[n] = alpha*o_x[n]/o_y[n]
void axdy(const dlong N, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y);
......@@ -95,6 +110,12 @@ public:
void axdyz(const dlong N, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y,
occa::memory& o_z);
// o_y[n] = alpha*o_y[n]/o_x[n]
void aydx(const dlong N, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y);
void aydxMany(const dlong N, const dlong Nfields, const dlong fieldOffset,
const dlong mode, const dfloat alpha,
occa::memory& o_x, occa::memory& o_y);
// \sum o_a
dfloat sum(const dlong N, occa::memory& o_a, MPI_Comm _comm);
......@@ -125,9 +146,14 @@ public:
occa::kernel scaleKernel;
occa::kernel axpbyKernel;
occa::kernel axpbyzKernel;
occa::kernel axpbyzManyKernel;
occa::kernel axmyKernel;
occa::kernel axmyManyKernel;
occa::kernel axmyzKernel;
occa::kernel axdyKernel;
occa::kernel aydxKernel;
occa::kernel aydxManyKernel;
occa::kernel adyKernel;
occa::kernel axdyzKernel;
occa::kernel sumKernel;
occa::kernel minKernel;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment