diff --git a/.gitignore b/.gitignore index ddc9ab3..c7b283a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,5 +12,6 @@ logs results npm-debug.log +package-lock.json node_modules/* *.DS_Store \ No newline at end of file diff --git a/fdg.js b/fdg.js index 8db7e43..69e2ce5 100644 --- a/fdg.js +++ b/fdg.js @@ -1,9 +1,99 @@ 'use strict' -module.exports = gradient +var dup = require('dup') + +var CACHED_CWiseOp = { + zero: function(SS, a0, t0, p0) { + var s0 = SS[0], t0p0 = t0[0] + p0 |= 0 + var i0 = 0, d0s0 = t0p0 + for (i0 = 0; i0 < s0; ++i0) { + a0[p0] = 0 + p0 += d0s0 + } + }, + + fdTemplate1: function(SS, a0, t0, p0, a1, t1, p1) { + var s0 = SS[0], t0p0 = t0[0], t1p0 = t1[0], q0 = -1 * t0p0, q1 = t0p0 + p0 |= 0 + p1 |= 0 + var i0 = 0, d0s0 = t0p0, d1s0 = t1p0 + for (i0 = 0; i0 < s0; ++i0) { + a1[p1] = 0.5 * (a0[p0 + q0] - a0[p0 + q1]) + p0 += d0s0 + p1 += d1s0 + } + }, + + fdTemplate2: function(SS, a0, t0, p0, a1, t1, p1, a2, t2, p2) { + var s0 = SS[0], s1 = SS[1], t0p0 = t0[0], t0p1 = t0[1], t1p0 = t1[0], t1p1 = t1[1], t2p0 = t2[0], t2p1 = t2[1], q0 = -1 * t0p0, q1 = t0p0, q2 = -1 * t0p1, q3 = t0p1 + p0 |= 0 + p1 |= 0 + p2 |= 0 + var i0 = 0, i1 = 0, d0s0 = t0p1, d0s1 = (t0p0 - s1 * t0p1), d1s0 = t1p1, d1s1 = (t1p0 - s1 * t1p1), d2s0 = t2p1, d2s1 = (t2p0 - s1 * t2p1) + for (i1 = 0; i1 < s0; ++i1) { + for (i0 = 0; i0 < s1; ++i0) { + a1[p1] = 0.5 * (a0[p0 + q0] - a0[p0 + q1]); a2[p2] = 0.5 * (a0[p0 + q2] - a0[p0 + q3]) + p0 += d0s0 + p1 += d1s0 + p2 += d2s0 + } + p0 += d0s1 + p1 += d1s1 + p2 += d2s1 + } + } +} + +var CACHED_thunk = { + cdiff: function(compile) { + var CACHED = {} + return function cdiff_cwise_thunk(array0, array1, array2) { + var t0 = array0.dtype, r0 = array0.order, t1 = array1.dtype, r1 = array1.order, t2 = array2.dtype, r2 = array2.order, type = [t0, r0.join(), t1, r1.join(), t2, r2.join()].join(), proc = CACHED[type] + if (!proc) { CACHED[type] = proc = compile([t0, r0, t1, r1, t2, r2]) } return proc(array0.shape.slice(0), array0.data, array0.stride, array0.offset | 0, array1.data, array1.stride, array1.offset | 0, array2.data, array2.stride, array2.offset | 0) + } + }, + + zero: function(compile) { + var CACHED = {} + return function zero_cwise_thunk(array0) { + var t0 = array0.dtype, r0 = array0.order, type = [t0, r0.join()].join(), proc = CACHED[type] + if (!proc) { CACHED[type] = proc = compile([t0, r0]) } return proc(array0.shape.slice(0), array0.data, array0.stride, array0.offset | 0) + } + }, + + fdTemplate1: function(compile) { + var CACHED = {} + return function fdTemplate1_cwise_thunk(array0, array1) { + var t0 = array0.dtype, r0 = array0.order, t1 = array1.dtype, r1 = array1.order, type = [t0, r0.join(), t1, r1.join()].join(), proc = CACHED[type] + if (!proc) { CACHED[type] = proc = compile([t0, r0, t1, r1]) } return proc(array0.shape.slice(0), array0.data, array0.stride, array0.offset | 0, array1.data, array1.stride, array1.offset | 0) + } + }, + + fdTemplate2: function(compile) { + var CACHED = {} + return function fdTemplate2_cwise_thunk(array0, array1, array4) { + var t0 = array0.dtype, r0 = array0.order, t1 = array1.dtype, r1 = array1.order, t4 = array4.dtype, r4 = array4.order, type = [t0, r0.join(), t1, r1.join(), t4, r4.join()].join(), proc = CACHED[type] + if (!proc) { CACHED[type] = proc = compile([t0, r0, t1, r1, t4, r4]) } return proc(array0.shape.slice(0), array0.data, array0.stride, array0.offset | 0, array1.data, array1.stride, array1.offset | 0, array4.data, array4.stride, array4.offset | 0) + } + }, +} + +function createThunk(proc) { + var thunk = CACHED_thunk[proc.funcName] + return thunk(compile.bind(undefined, proc)) +} + +function compile(proc) { + return CACHED_CWiseOp[proc.funcName] +} + +function cwiseCompiler(user_args) { + return createThunk({ + funcName: user_args.funcName + }) +} -var dup = require('dup') -var cwiseCompiler = require('cwise-compiler') var TEMPLATE_CACHE = {} var GRADIENT_CACHE = {} @@ -16,48 +106,10 @@ var EmptyProc = { } var centralDiff = cwiseCompiler({ - args: [ 'array', 'array', 'array' ], - pre: EmptyProc, - post: EmptyProc, - body: { - args: [ { - name: 'out', - lvalue: true, - rvalue: false, - count: 1 - }, { - name: 'left', - lvalue: false, - rvalue: true, - count: 1 - }, { - name: 'right', - lvalue: false, - rvalue: true, - count: 1 - }], - body: "out=0.5*(left-right)", - thisVars: [], - localVars: [] - }, funcName: 'cdiff' }) var zeroOut = cwiseCompiler({ - args: [ 'array' ], - pre: EmptyProc, - post: EmptyProc, - body: { - args: [ { - name: 'out', - lvalue: true, - rvalue: false, - count: 1 - }], - body: "out=0", - thisVars: [], - localVars: [] - }, funcName: 'zero' }) @@ -65,224 +117,143 @@ function generateTemplate(d) { if(d in TEMPLATE_CACHE) { return TEMPLATE_CACHE[d] } - var code = [] - for(var i=0; i= 0) { - pickStr.push('0') - } else if(facet.indexOf(-(i+1)) >= 0) { - pickStr.push('s['+i+']-1') - } else { - pickStr.push('-1') - loStr.push('1') - hiStr.push('s['+i+']-2') - } +function CACHED_link(diff, zero, grad1, grad2) { + return function(dst, src) { + var s = src.shape.slice() + if (1 && s[0] > 2 && s[1] > 2) { + grad2( + src + .pick(-1, -1) + .lo(1, 1) + .hi(s[0] - 2, s[1] - 2), + dst + .pick(-1, -1, 0) + .lo(1, 1) + .hi(s[0] - 2, s[1] - 2), + dst + .pick(-1, -1, 1) + .lo(1, 1) + .hi(s[0] - 2, s[1] - 2) + ) } - var boundStr = '.lo(' + loStr.join() + ').hi(' + hiStr.join() + ')' - if(loStr.length === 0) { - boundStr = '' + if (1 && s[1] > 2) { + grad1( + src + .pick(0, -1) + .lo(1) + .hi(s[1] - 2), + dst + .pick(0, -1, 1) + .lo(1) + .hi(s[1] - 2) + ) + zero( + dst + .pick(0, -1, 0) + .lo(1) + .hi(s[1] - 2) + ) } - - if(cod > 0) { - code.push('if(1') - for(var i=0; i= 0 || facet.indexOf(-(i+1)) >= 0) { - continue - } - code.push('&&s[', i, ']>2') - } - code.push('){grad', cod, '(src.pick(', pickStr.join(), ')', boundStr) - for(var i=0; i= 0 || facet.indexOf(-(i+1)) >= 0) { - continue - } - code.push(',dst.pick(', pickStr.join(), ',', i, ')', boundStr) - } - code.push(');') + if (1 && s[1] > 2) { + grad1( + src + .pick(s[0] - 1, -1) + .lo(1) + .hi(s[1] - 2), + dst + .pick(s[0] - 1, -1, 1) + .lo(1) + .hi(s[1] - 2) + ) + zero( + dst + .pick(s[0] - 1, -1, 0) + .lo(1) + .hi(s[1] - 2) + ) } - - for(var i=0; i1){dst.set(', - pickStr.join(), ',', bnd, ',0.5*(src.get(', - cPickStr.join(), ')-src.get(', - dPickStr.join(), ')))}else{dst.set(', - pickStr.join(), ',', bnd, ',0)};') - } else { - code.push('if(s[', bnd, ']>1){diff(', outStr, - ',src.pick(', cPickStr.join(), ')', boundStr, - ',src.pick(', dPickStr.join(), ')', boundStr, - ');}else{zero(', outStr, ');};') - } - break - - case 'mirror': - if(cod === 0) { - code.push('dst.set(', pickStr.join(), ',', bnd, ',0);') - } else { - code.push('zero(', outStr, ');') - } - break - - case 'wrap': - var aPickStr = pickStr.slice() - var bPickStr = pickStr.slice() - if(facet[i] < 0) { - aPickStr[bnd] = 's[' + bnd + ']-2' - bPickStr[bnd] = '0' - - } else { - aPickStr[bnd] = 's[' + bnd + ']-1' - bPickStr[bnd] = '1' - } - if(cod === 0) { - code.push('if(s[', bnd, ']>2){dst.set(', - pickStr.join(), ',', bnd, ',0.5*(src.get(', - aPickStr.join(), ')-src.get(', - bPickStr.join(), ')))}else{dst.set(', - pickStr.join(), ',', bnd, ',0)};') - } else { - code.push('if(s[', bnd, ']>2){diff(', outStr, - ',src.pick(', aPickStr.join(), ')', boundStr, - ',src.pick(', bPickStr.join(), ')', boundStr, - ');}else{zero(', outStr, ');};') - } - break - - default: - throw new Error('ndarray-gradient: Invalid boundary condition') - } + if (1 && s[0] > 2) { + grad1( + src + .pick(-1, 0) + .lo(1) + .hi(s[0] - 2), + dst + .pick(-1, 0, 0) + .lo(1) + .hi(s[0] - 2) + ) + zero( + dst + .pick(-1, 0, 1) + .lo(1) + .hi(s[0] - 2) + ) } - - if(cod > 0) { - code.push('};') + if (1 && s[0] > 2) { + grad1( + src + .pick(-1, s[1] - 1) + .lo(1) + .hi(s[0] - 2), + dst + .pick(-1, s[1] - 1, 0) + .lo(1) + .hi(s[0] - 2) + ) + zero( + dst + .pick(-1, s[1] - 1, 1) + .lo(1) + .hi(s[0] - 2) + ) } + dst.set(0, 0, 0, 0) + dst.set(0, 0, 1, 0) + dst.set(s[0] - 1, 0, 0, 0) + dst.set(s[0] - 1, 0, 1, 0) + dst.set(0, s[1] - 1, 0, 0) + dst.set(0, s[1] - 1, 1, 0) + dst.set(s[0] - 1, s[1] - 1, 0, 0) + dst.set(s[0] - 1, s[1] - 1, 1, 0) + return dst } +} - //Enumerate ridges, facets, etc. of hypercube - for(var i=0; i<(1<