import * as twgl from 'twgl.js';
import { ViewManager } from "../../../view-manager";
import * as xShaders from './margin-shaders-sweep-x';
import * as yShaders from './margin-shaders-sweep-y';
import * as zShaders from './margin-shaders-sweep-z';
import * as smoothingShaders from './margin-shaders-smoothing';
import { Roi, StructureSet } from "../../../../dicom/structure-set";
import { createVertexBuffers } from "../../create-vertex-buffers";
import * as m4 from '../../../../math/m4';
import { Sdf } from '../sdf';
import { Propagation } from '../propagation/sdf-propagation';
import { SdfOperations } from '../boolean/sdf-operations';
import { SdfSlice } from '../sdf-slice';
import { mouseTools } from '../../../mouse-tools/mouse-tools';

/** Largest margin, in mm, accepted for any single direction. */
export const maxMarginMM = 50;

/**
 * Validated parameters for {@link MarginOperations.addMargin}: the ROI to read, the ROI that
 * receives the result, whether the margin is an inner or an outer margin, and the margin in mm
 * for each of the six directions.
 */
export class MarginOptions {
    sourceRoi: Roi;
    targetRoi: Roi;
    /** true for an inner margin, false for an outer margin. */
    isInnerMargin: boolean;
    
    leftMM: number;
    rightMM: number;
    cranialMM: number;
    caudalMM: number;
    anteriorMM: number;
    posteriorMM: number;

    /**
     * @param sourceRoi ROI used as input; must have an SDF (i.e. contours).
     * @param targetRoi ROI that receives the resulting SDF.
     * @param isInner true for an inner margin, false for an outer margin.
     * @param leftMM margin towards the left, in mm.
     * @param rightMM margin towards the right, in mm.
     * @param cranialMM margin towards cranial, in mm.
     * @param caudalMM margin towards caudal, in mm.
     * @param anteriorMM margin towards anterior, in mm.
     * @param posteriorMM margin towards posterior, in mm.
     * @throws Error if the source ROI has no SDF, or any margin is NaN or outside
     *         [0, maxMarginMM].
     */
    constructor(sourceRoi: Roi, targetRoi: Roi,
         isInner: boolean, leftMM: number, rightMM: number, cranialMM: number, caudalMM: number, 
         anteriorMM: number, posteriorMM: number) {

        if(!sourceRoi.sdf) throw new Error("Source structure '" + sourceRoi.name + "' does not have contours!");

        for(const m of [leftMM, rightMM, cranialMM, caudalMM, anteriorMM, posteriorMM]) {
            // NaN fails both `m < 0` and `m > maxMarginMM`, so it has to be rejected explicitly;
            // otherwise it would flow into the bounding-box arithmetic of the margin operation.
            if(Number.isNaN(m) || m < 0 || m > maxMarginMM) {
                throw new Error("Margins must be between 0 and " + maxMarginMM.toString() + " mm");
            }
        }

        this.sourceRoi = sourceRoi;
        this.targetRoi = targetRoi;
        this.isInnerMargin = isInner;

        this.leftMM = leftMM;
        this.rightMM = rightMM;
        this.cranialMM = cranialMM;
        this.caudalMM = caudalMM;
        this.anteriorMM = anteriorMM;
        this.posteriorMM = posteriorMM;
    }

}

export class MarginOperations {

    /** Full-screen quad (two triangles) in clip space, drawn once per render pass. */
    private static readonly quadPositions = [
        -1, -1,
        1, -1,
        1, 1,
        1, 1,
        -1, 1,
        -1, -1
    ];
    /** Texture coordinates for {@link quadPositions}; clip-space y = -1 maps to v = 1. */
    private static readonly quadTexCoords = [
        0.0, 1.0,
        1.0, 1.0,
        1.0, 0.0,
        1.0, 0.0,
        0.0, 0.0,
        0.0, 1.0
    ];
    /** Number of in-plane smoothing passes run after the cranial/caudal sweep. */
    private static readonly smoothingRounds = 7;

    private readonly viewManager: ViewManager;
    private readonly xProgram: WebGLProgram;
    private readonly xUniformLoc: any;
    private readonly yProgram: WebGLProgram;
    private readonly yUniformLoc: any;
    private readonly zProgram: WebGLProgram;
    private readonly zUniformLoc: any;
    private readonly smoothingProgram: WebGLProgram;
    private readonly smoothingUniformLoc: any;
    
    /**
     * Builds the X, Y and Z sweep programs and the smoothing program on the view manager's
     * WebGL context and caches their uniform locations.
     */
    constructor(viewManager: ViewManager) {
        this.viewManager = viewManager;
        const gl = viewManager.getWebGlContext();
        this.xProgram = twgl.createProgramInfo((gl as any), [xShaders.MARGIN_SWEEP_X_VS, xShaders.MARGIN_SWEEP_X_FS]).program;
        this.xUniformLoc = {
            textureOrig: gl.getUniformLocation(this.xProgram, 'textureOrig'),
            textureBuffer: gl.getUniformLocation(this.xProgram, 'textureBuffer'),
            textureSize: gl.getUniformLocation(this.xProgram, 'textureSize'),
            maxDistancePixels: gl.getUniformLocation(this.xProgram, 'maxDistancePixels'),
            distancePixels1: gl.getUniformLocation(this.xProgram, 'distancePixels1'),
            distancePixels2: gl.getUniformLocation(this.xProgram, 'distancePixels2'),
            x1MarginPixels: gl.getUniformLocation(this.xProgram, 'x1MarginPixels'),
            x2MarginPixels: gl.getUniformLocation(this.xProgram, 'x2MarginPixels'),
            isInnerMargin: gl.getUniformLocation(this.xProgram, 'isInnerMargin'),
        };

        this.yProgram = twgl.createProgramInfo((gl as any), [yShaders.MARGIN_SWEEP_Y_VS, yShaders.MARGIN_SWEEP_Y_FS]).program;
        this.yUniformLoc = {
            textureOrig: gl.getUniformLocation(this.yProgram, 'textureOrig'),
            textureBuffer: gl.getUniformLocation(this.yProgram, 'textureBuffer'),
            textureSize: gl.getUniformLocation(this.yProgram, 'textureSize'),
            maxDistancePixels: gl.getUniformLocation(this.yProgram, 'maxDistancePixels'),
            distancePixels1: gl.getUniformLocation(this.yProgram, 'distancePixels1'),
            distancePixels2: gl.getUniformLocation(this.yProgram, 'distancePixels2'),
            y1MarginPixels: gl.getUniformLocation(this.yProgram, 'y1MarginPixels'),
            y2MarginPixels: gl.getUniformLocation(this.yProgram, 'y2MarginPixels'),
            isInnerMargin: gl.getUniformLocation(this.yProgram, 'isInnerMargin'),
        };

        this.zProgram = twgl.createProgramInfo((gl as any), [zShaders.MARGIN_SWEEP_Z_VS, zShaders.MARGIN_SWEEP_Z_FS]).program;
        this.zUniformLoc = {
            textureOrig: gl.getUniformLocation(this.zProgram, 'textureOrig'),
            textureBuffer: gl.getUniformLocation(this.zProgram, 'textureBuffer'),
            isInnerMargin: gl.getUniformLocation(this.zProgram, 'isInnerMargin'),
            textureMatrix1: gl.getUniformLocation(this.zProgram, 'orientation1'),
            textureMatrix2: gl.getUniformLocation(this.zProgram, 'orientation2'),
            textureMatrix3: gl.getUniformLocation(this.zProgram, 'orientation3'),
            textureMatrix4: gl.getUniformLocation(this.zProgram, 'orientation4'),
            maxDistancePixels: gl.getUniformLocation(this.zProgram, 'maxDistancePixels'),
            distancePixels1: gl.getUniformLocation(this.zProgram, 'distancePixels1'),
            distancePixels2: gl.getUniformLocation(this.zProgram, 'distancePixels2'),
            z1MarginPixels: gl.getUniformLocation(this.zProgram, 'z1MarginPixels'),
            z2MarginPixels: gl.getUniformLocation(this.zProgram, 'z2MarginPixels'),
            interpolationWeight1: gl.getUniformLocation(this.zProgram, 'interpolationWeight1'),
            interpolationWeight2: gl.getUniformLocation(this.zProgram, 'interpolationWeight2'),
        };

        this.smoothingProgram = twgl.createProgramInfo((gl as any), [smoothingShaders.MARGIN_SMOOTHING_VS, smoothingShaders.MARGIN_SMOOTHING_FS]).program;
        this.smoothingUniformLoc = {
            textureData: gl.getUniformLocation(this.smoothingProgram, 'textureData'),
            textureSize: gl.getUniformLocation(this.smoothingProgram, 'textureSize'),
        };
    }

    /**
     * Applies the margin described by `opt` to `opt.sourceRoi` and stores the result as the SDF
     * of `opt.targetRoi`.
     *
     * Steps:
     * 1. Propagate the source if a propagation is pending.
     * 2. Push the target's current state onto the undo stack.
     * 3. Copy the source into a working SDF covering its bounding box. For an outer margin the
     *    box is enlarged by the margins, cropped to the image, and rounded to full pixels.
     * 4. Slice by slice, sweep along X (left/right) and then Y (posterior/anterior), writing each
     *    result back into the working SDF.
     * 5. If a cranial/caudal margin is requested, sweep along Z into a scratch volume. Then run
     *    {@link MarginOperations.smoothingRounds} in-plane smoothing passes from the scratch
     *    volume back into the working SDF.
     * 6. Make the working SDF the target's SDF and propagate it.
     *
     * The temporary framebuffer and textures created here are deleted before returning, also
     * when an error is thrown.
     *
     * @param opt validated margin parameters.
     * @throws Error if the source ROI has no SDF.
     */
    public addMargin( opt: MarginOptions ) {
        const startTime = Date.now();
        const vm = this.viewManager;
        const img = vm.image;
        const gl = vm.getWebGlContext();

        let sourceSdf = opt.sourceRoi.sdf as Sdf;
        if(!sourceSdf) throw new Error("Source structure does not have contours!");

        // Propagate before resolution and bounding box are read, so they come from the
        // up-to-date SDF. The ROI's SDF is re-read in case propagation replaces the object.
        // NOTE(review): confirm whether Propagation.propagate updates the SDF in place.
        if(sourceSdf.propagationPending) {
            new Propagation(vm).propagate(opt.sourceRoi);
            sourceSdf = opt.sourceRoi.sdf as Sdf;
            if(!sourceSdf) throw new Error("Source structure does not have contours!");
        }

        const resolutionMm = sourceSdf.resolutionMm;
        const maxDistanceMm = sourceSdf.maxDistanceMm;

        const bb = sourceSdf.boundingBox.copy();
        if(!opt.isInnerMargin) {
            // Enlarge the working volume by the outer margins, then clip it to the image.
            bb.minI -= opt.rightMM;
            bb.maxI += opt.leftMM;
            bb.minJ -= opt.anteriorMM;
            bb.maxJ += opt.posteriorMM;
            bb.minK -= opt.caudalMM;
            bb.maxK += opt.cranialMM;
            bb.cropToImageBorderMM(img);
            bb.roundToFullPixels(img);
        }
        
        vm.viewerState.undoStack.pushRoiStateBeforeEdit(opt.targetRoi);

        const sdfOp = new SdfOperations(vm);
        // Working copy of the source restricted to bb. Every sweep writes into it, and it
        // finally becomes the target's SDF.
        const workSdf = sdfOp.copy(sourceSdf, bb, true);
        const size = workSdf.size;
        const width = size[0];
        const height = size[1];
        const depth = size[2];

        // Two slice buffers ping-ponged between passes, plus a copy of the slice before sweeping.
        const sliceBuffer1 = new SdfSlice(vm, resolutionMm, maxDistanceMm);
        sliceBuffer1.createTexture(width, height);
        const sliceBuffer2 = new SdfSlice(vm, resolutionMm, maxDistanceMm);
        sliceBuffer2.createTexture(width, height);
        const origBuffer = new SdfSlice(vm, resolutionMm, maxDistanceMm);
        origBuffer.createTexture(width, height);

        // GPU objects owned by this call and deleted in `finally`.
        // NOTE(review): assumes SdfSlice/Sdf textures are not pooled or shared elsewhere.
        const ownedTextures: (WebGLTexture | null)[] = [sliceBuffer1.data, sliceBuffer2.data, origBuffer.data];
        // A single framebuffer, re-attached for every pass.
        const fb = gl.createFramebuffer();
        // One quad VAO per program, created after that program is made current. It is cached
        // per program in case createVertexBuffers wires attributes for the program in use.
        const quadVaos = new Map<WebGLProgram, ReturnType<typeof createVertexBuffers>>();

        /**
         * Runs `rounds` full-slice draws of `program`. Each draw samples the previous result on
         * texture unit 0 (via `bufferUniform`) and writes into the other slice buffer. The final
         * draw is redirected into layer `layer` of the 3D texture `target`. `sliceBuffer1` must
         * already hold the input. `prepare(round)` binds any further textures and sets the
         * per-round uniforms.
         */
        const renderPingPong = (
            program: WebGLProgram,
            bufferUniform: WebGLUniformLocation | null,
            rounds: number,
            target: WebGLTexture | null,
            layer: number,
            prepare: (round: number) => void,
        ) => {
            gl.useProgram(program);
            let vao = quadVaos.get(program);
            if(vao === undefined) {
                vao = createVertexBuffers(gl, MarginOperations.quadPositions, MarginOperations.quadTexCoords);
                quadVaos.set(program, vao);
            }

            let readBuffer = sliceBuffer1;
            let writeBuffer = sliceBuffer2;
            for(let round = 1; round <= rounds; round++) {
                gl.bindFramebuffer(gl.FRAMEBUFFER, fb);
                if(round === rounds) {
                    gl.framebufferTextureLayer(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, target, 0, layer);
                }
                else {
                    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, writeBuffer.data, 0);
                }
                gl.uniform1i(bufferUniform, 0);
                gl.activeTexture(gl.TEXTURE0);
                gl.bindTexture(gl.TEXTURE_2D, readBuffer.data);

                prepare(round);

                gl.viewport(0, 0, width, height);
                gl.bindVertexArray(vao);
                gl.drawArraysInstanced(gl.TRIANGLES, 0, 6, 1);
                gl.bindFramebuffer(gl.FRAMEBUFFER, null);

                [readBuffer, writeBuffer] = [writeBuffer, readBuffer];
            }
        };

        /**
         * Sweeps slice `z` of the working SDF along one in-plane axis with `program`. The two
         * margins (in mm) are converted to pixels with `pixelsPerMm`. The result is written
         * back into layer `z` of the working SDF.
         */
        const sweepInPlane = (
            z: number,
            program: WebGLProgram,
            loc: any,
            margin1Loc: WebGLUniformLocation | null,
            margin2Loc: WebGLUniformLocation | null,
            pixelsPerMm: number,
            margin1Mm: number,
            margin2Mm: number,
        ) => {
            sdfOp.copySdfSlice(workSdf, z, origBuffer);
            sdfOp.copySdfSlice(workSdf, z, sliceBuffer1);

            const maxDistancePixels = maxDistanceMm * pixelsPerMm;
            const margin1Pixels = margin1Mm * pixelsPerMm;
            const margin2Pixels = margin2Mm * pixelsPerMm;
            const rounds = Math.ceil( Math.max(margin1Pixels, margin2Pixels) );

            renderPingPong(program, loc.textureBuffer, rounds, workSdf.data, z, (round) => {
                gl.uniform1i(loc.textureOrig, 1);
                gl.activeTexture(gl.TEXTURE1);
                gl.bindTexture(gl.TEXTURE_2D, origBuffer.data);

                gl.uniform2fv(loc.textureSize, [width, height]);
                gl.uniform1f(loc.maxDistancePixels, maxDistancePixels);
                gl.uniform1f(loc.distancePixels1, MarginOperations.sweepDistance(round, margin1Pixels));
                gl.uniform1f(loc.distancePixels2, MarginOperations.sweepDistance(round, margin2Pixels));
                gl.uniform1f(margin1Loc, margin1Pixels);
                gl.uniform1f(margin2Loc, margin2Pixels);
                gl.uniform1f(loc.isInnerMargin, opt.isInnerMargin ? 1 : 0);
            });
        };

        // Texture matrix passed to the Z shader for slice `zSlice`. It translates the third
        // coordinate to (0.5 + zSlice) / depth, then rotates by PI about X.
        const sliceOrientation = (zSlice: number) => {
            const sliceThickness = 1 / depth;
            const sliceCentre = (0.5 + zSlice) * sliceThickness;
            let m = m4.translation(0, 1, sliceCentre);
            m = m4.xRotate(m, Math.PI);
            return m;
        };

        // Sets the weight/distance uniforms for one side of the Z sweep in a given round.
        // The weight is -1 when the sampled slice is outside the volume or the round is past
        // this side's margin (presumably tells the shader to skip the side; the distance
        // uniform is left untouched then). It is fractional on the round that overshoots a
        // non-integer margin, and 0 otherwise.
        const setZSide = (
            weightLoc: WebGLUniformLocation | null,
            distanceLoc: WebGLUniformLocation | null,
            sampleSlice: number,
            round: number,
            marginPixels: number,
        ) => {
            if(sampleSlice < 0 || sampleSlice >= depth || round >= marginPixels + 1) {
                gl.uniform1f(weightLoc, -1);
            }
            else if(round > marginPixels) {
                gl.uniform1f(weightLoc, round - marginPixels);
                gl.uniform1f(distanceLoc, marginPixels);
            }
            else {
                gl.uniform1f(weightLoc, 0);
                gl.uniform1f(distanceLoc, round);
            }
        };

        try {
            // 1. + 2. In-plane sweeps, slice by slice: X (left/right), then Y (posterior/anterior).
            for(let z = 0; z < depth; ++z) {
                if(opt.leftMM > 0 || opt.rightMM > 0) {
                    sweepInPlane(z, this.xProgram, this.xUniformLoc,
                        this.xUniformLoc.x1MarginPixels, this.xUniformLoc.x2MarginPixels,
                        width / bb.getXSize(), opt.leftMM, opt.rightMM);
                }
                if(opt.anteriorMM > 0 || opt.posteriorMM > 0) {
                    sweepInPlane(z, this.yProgram, this.yUniformLoc,
                        this.yUniformLoc.y1MarginPixels, this.yUniformLoc.y2MarginPixels,
                        height / bb.getYSize(), opt.posteriorMM, opt.anteriorMM);
                }
            }

            // 3. Z sweep into a scratch volume, then in-plane smoothing back into the working SDF.
            if(opt.cranialMM > 0 || opt.caudalMM > 0) {
                const scratch = new Sdf(vm, resolutionMm);
                scratch.createTexture(bb, false, false);
                ownedTextures.push(scratch.data);

                const zLoc = this.zUniformLoc;
                const zPixelsPerMm = depth / bb.getZSize();
                const z1MarginPixels = opt.cranialMM * zPixelsPerMm;
                const z2MarginPixels = opt.caudalMM * zPixelsPerMm;
                const zRounds = Math.ceil( Math.max(z1MarginPixels, z2MarginPixels) );
                // Slice offset direction for side 1 (paired with the cranial margin).
                // Side 2 (caudal) samples in the opposite direction.
                const dir = opt.isInnerMargin ? 1 : -1;

                for(let z = 0; z < depth; ++z) {
                    sdfOp.copySdfSlice(workSdf, z, sliceBuffer1);
                    renderPingPong(this.zProgram, zLoc.textureBuffer, zRounds, scratch.data, z, (round) => {
                        gl.uniform1i(zLoc.textureOrig, 1);
                        gl.activeTexture(gl.TEXTURE1);
                        gl.bindTexture(gl.TEXTURE_3D, workSdf.data);

                        const z1 = z + dir * round;
                        const z2 = z1 - dir;
                        gl.uniformMatrix4fv(zLoc.textureMatrix1, false, sliceOrientation(z1));
                        gl.uniformMatrix4fv(zLoc.textureMatrix2, false, sliceOrientation(z2));
                        setZSide(zLoc.interpolationWeight1, zLoc.distancePixels1, z1, round, z1MarginPixels);

                        const z3 = z - dir * round;
                        const z4 = z3 + dir;
                        gl.uniformMatrix4fv(zLoc.textureMatrix3, false, sliceOrientation(z3));
                        gl.uniformMatrix4fv(zLoc.textureMatrix4, false, sliceOrientation(z4));
                        setZSide(zLoc.interpolationWeight2, zLoc.distancePixels2, z3, round, z2MarginPixels);

                        gl.uniform1f(zLoc.z1MarginPixels, z1MarginPixels);
                        gl.uniform1f(zLoc.z2MarginPixels, z2MarginPixels);
                        gl.uniform1f(zLoc.isInnerMargin, opt.isInnerMargin ? 1 : 0);
                        // NOTE(review): zLoc.maxDistancePixels is looked up in the constructor but
                        // never set here, unlike the X/Y sweeps; confirm whether the Z shader uses it.
                    });
                }

                for(let z = 0; z < depth; ++z) {
                    sdfOp.copySdfSlice(scratch, z, sliceBuffer1);
                    renderPingPong(this.smoothingProgram, this.smoothingUniformLoc.textureData,
                        MarginOperations.smoothingRounds, workSdf.data, z, () => {
                            gl.uniform2fv(this.smoothingUniformLoc.textureSize, [width, height]);
                        });
                }
            }

            opt.targetRoi.sdf = workSdf;
            opt.targetRoi.setContoursChanged(null);
            opt.targetRoi.structureSet.unsaved = true;
            new Propagation(vm).propagate(opt.targetRoi);
        }
        finally {
            gl.deleteFramebuffer(fb);
            for(const texture of ownedTextures) {
                gl.deleteTexture(texture);
            }
        }

        if(mouseTools.brush.brushBuffer) {
            mouseTools.brush.createDrawBuffer();
        }

        console.log("Adding margin took " + (Date.now() - startTime) + " milliseconds");
        vm.viewerState.notifyListeners();
    }

    /**
     * Returns the pixel distance for one side of an X/Y sweep in pass `pass`:
     * - the pass index while still inside the margin;
     * - the exact (possibly fractional) margin on the single pass that overshoots it;
     * - 0 once the margin is exhausted.
     */
    private static sweepDistance(pass: number, marginPixels: number): number {
        if(pass <= marginPixels) return pass;
        if(pass < marginPixels + 1) return marginPixels;
        return 0;
    }
}