import * as twgl from 'twgl.js';
import { ViewManager } from "../../../view-manager";
import * as seedingShaders from './propagation-shaders-seeding';
import * as jfaShaders from './propagation-shaders-jfa';
import * as vectorToSdfShaders from './propagation-shaders-vector-to-sdf';
import { Roi } from "../../../../dicom/structure-set";
import { createVertexBuffers } from "../../create-vertex-buffers";
import * as m4 from '../../../../math/m4';
import { VectorField } from '../vector-field';
import { Sdf } from '../sdf';

/** Uniform-location lookup table keyed by a fixed set of names. */
type UniformLocations<K extends string> = Record<K, WebGLUniformLocation | null>;

/**
 * Rebuilds the signed distance field (SDF) of a ROI on the GPU.
 *
 * Every slice of the source SDF volume is processed by three shader passes:
 *   1. seeding                       - source SDF        -> vector field
 *   2. jump flooding algorithm (JFA) - vector field      <-> vector field (ping-pong)
 *   3. vector field to SDF           - vector field + source SDF -> destination SDF
 *
 * The shader programs are compiled once, in the constructor.
 */
export class Propagation {

    private readonly viewManager: ViewManager;

    private readonly seedingProgram: WebGLProgram;
    private readonly seedingUniformLoc: UniformLocations<'textureMatrix' | 'textureData' | 'textureSize'>;
    private readonly jfaProgram: WebGLProgram;
    private readonly jfaUniformLoc: UniformLocations<'textureMatrix' | 'textureData' | 'textureSize' | 'distancePixels' | 'maxDistancePixels'>;
    private readonly vectorToSdfProgram: WebGLProgram;
    private readonly vectorToSdfUniformLoc: UniformLocations<'textureMatrix' | 'vectorTexture' | 'sdfTexture' | 'maxDistancePixels'>;

    /**
     * Compiles the three shader programs and caches their uniform locations.
     *
     * @param viewManager supplies the WebGL context and the viewer state that is
     *                    notified once a propagation has finished.
     */
    constructor(viewManager: ViewManager) {
        this.viewManager = viewManager;
        const gl = viewManager.getWebGlContext();

        this.seedingProgram = twgl.createProgramInfo((gl as any), [seedingShaders.SEEDING_VS, seedingShaders.SEEDING_FS]).program;
        this.seedingUniformLoc = {
            // The shader-side name of the texture matrix uniform is 'orientation'.
            textureMatrix: gl.getUniformLocation(this.seedingProgram, 'orientation'),
            textureData: gl.getUniformLocation(this.seedingProgram, 'textureData'),
            textureSize: gl.getUniformLocation(this.seedingProgram, 'textureSize'),
        };

        this.jfaProgram = twgl.createProgramInfo((gl as any), [jfaShaders.JFA_VS, jfaShaders.JFA_FS]).program;
        this.jfaUniformLoc = {
            textureMatrix: gl.getUniformLocation(this.jfaProgram, 'orientation'),
            textureData: gl.getUniformLocation(this.jfaProgram, 'textureData'),
            textureSize: gl.getUniformLocation(this.jfaProgram, 'textureSize'),
            distancePixels: gl.getUniformLocation(this.jfaProgram, 'distancePixels'),
            maxDistancePixels: gl.getUniformLocation(this.jfaProgram, 'maxDistancePixels'),
        };

        this.vectorToSdfProgram = twgl.createProgramInfo((gl as any), [vectorToSdfShaders.VECTOR_TO_SDF_VS, vectorToSdfShaders.VECTOR_TO_SDF_FS]).program;
        this.vectorToSdfUniformLoc = {
            textureMatrix: gl.getUniformLocation(this.vectorToSdfProgram, 'orientation'),
            vectorTexture: gl.getUniformLocation(this.vectorToSdfProgram, 'vectorTexture'),
            sdfTexture: gl.getUniformLocation(this.vectorToSdfProgram, 'sdfTexture'),
            maxDistancePixels: gl.getUniformLocation(this.vectorToSdfProgram, 'maxDistancePixels'),
        };
    }

    /**
     * Recomputes `roi.sdf` slice by slice and replaces it with the result.
     *
     * Does nothing when the ROI has no SDF. On completion the new SDF is assigned
     * to `roi.sdf`, its `propagationPending` flag is cleared, the elapsed time is
     * logged and the viewer state listeners are notified.
     *
     * @param roi the ROI whose SDF is rebuilt in place.
     */
    public propagate(roi: Roi ) {
        const startTime = Date.now();
        const vm = this.viewManager;
        if(!roi.sdf) return;
        const sourceSdf = roi.sdf;
        const resolutionMm = sourceSdf.resolutionMm;
        const maxDistanceMm = sourceSdf.maxDistanceMm;
        const gl = vm.getWebGlContext();
        const bb = sourceSdf.boundingBox;
        const size = sourceSdf.size;

        // Two vector field volumes that are ping-ponged during JFA.
        // NOTE(review): these are scratch buffers that are not referenced after this
        // method returns; if VectorField exposes a way to release its texture it
        // should be called at the end - confirm against the VectorField class.
        const vectorField1 = new VectorField(vm, resolutionMm, maxDistanceMm);
        vectorField1.createTexture(bb.copy());
        const vectorField2 = new VectorField(vm, resolutionMm, maxDistanceMm);
        vectorField2.createTexture(bb.copy());

        const destSdf = new Sdf(vm, resolutionMm);
        destSdf.createTexture(bb.copy(), false, false);

        // The scale is derived from the x axis only.
        // NOTE(review): presumably the in-plane resolution is isotropic - confirm.
        const pixelsPerMm = size[0] / bb.getXSize();
        const maxDistancePixels = maxDistanceMm * pixelsPerMm;

        // Largest power-of-two jump that does not exceed the maximum distance
        // (at least 1, so JFA always runs at least one pass).
        let initialJumpPixels = 1;
        while(initialJumpPixels * 2 <= maxDistancePixels) {
            initialJumpPixels *= 2;
        }

        // Quad covering the whole viewport (two triangles) and its texture coordinates.
        const quadPositions = [
            -1, -1,
            1, -1,
            1, 1,
            1, 1,
            -1, 1,
            -1, -1
        ];
        const quadTextureCoords = [
            0.0, 1.0,
            1.0, 1.0,
            1.0, 0.0,
            1.0, 0.0,
            0.0, 0.0,
            0.0, 1.0
        ];

        // One vertex array per program instead of one per draw call. Each is created
        // while its program is current, in case the helper resolves attribute
        // locations against the current program.
        // NOTE(review): the vertex arrays are not deleted here because ownership of
        // the objects returned by createVertexBuffers is not visible from this file.
        gl.useProgram(this.seedingProgram);
        const seedingVao = createVertexBuffers(gl, quadPositions, quadTextureCoords);
        gl.useProgram(this.jfaProgram);
        const jfaVao = createVertexBuffers(gl, quadPositions, quadTextureCoords);
        gl.useProgram(this.vectorToSdfProgram);
        const vectorToSdfVao = createVertexBuffers(gl, quadPositions, quadTextureCoords);

        // A single framebuffer is reused for every pass; only its colour attachment
        // changes. It is deleted in the finally block so it cannot leak.
        const framebuffer = gl.createFramebuffer();
        const sliceCount = size[2];
        const sliceStep = 1 / sliceCount;

        try {
            gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
            // Every pass renders a full slice, so the viewport never changes.
            gl.viewport(0, 0, size[0], size[1]);

            for(let z = 0; z < sliceCount; ++z) {
                // Texture matrix selecting the centre of slice z along the third axis.
                const scroll = (0.5 + z) * sliceStep;
                let textureMatrix = m4.translation(0, 1, scroll);
                textureMatrix = m4.xRotate(textureMatrix, Math.PI);

                // 1. Seeding: source SDF -> slice z of vector field 1
                gl.useProgram(this.seedingProgram);
                gl.framebufferTextureLayer(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, vectorField1.data, 0, z);

                gl.uniform1i(this.seedingUniformLoc.textureData, 0);
                gl.activeTexture(gl.TEXTURE0);
                gl.bindTexture(gl.TEXTURE_3D, sourceSdf.data);

                gl.uniformMatrix4fv(this.seedingUniformLoc.textureMatrix, false, textureMatrix);
                gl.uniform3fv(this.seedingUniformLoc.textureSize, size);
                gl.bindVertexArray(seedingVao);
                gl.drawArraysInstanced(gl.TRIANGLES, 0, 6, 1);

                // 2. Jump Flooding Algorithm (JFA): halve the jump distance each pass,
                //    reading from one vector field and writing slice z of the other.
                gl.useProgram(this.jfaProgram);
                gl.uniform1i(this.jfaUniformLoc.textureData, 0);
                gl.uniform1f(this.jfaUniformLoc.maxDistancePixels, maxDistancePixels);
                gl.uniformMatrix4fv(this.jfaUniformLoc.textureMatrix, false, textureMatrix);
                gl.uniform3fv(this.jfaUniformLoc.textureSize, size);
                gl.activeTexture(gl.TEXTURE0);
                gl.bindVertexArray(jfaVao);

                let readField = vectorField1;
                let writeField = vectorField2;
                for(let distancePixels = initialJumpPixels; distancePixels >= 1; distancePixels /= 2) {
                    gl.framebufferTextureLayer(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, writeField.data, 0, z);
                    gl.bindTexture(gl.TEXTURE_3D, readField.data);
                    gl.uniform1f(this.jfaUniformLoc.distancePixels, distancePixels);
                    gl.drawArraysInstanced(gl.TRIANGLES, 0, 6, 1);

                    // The buffer just written becomes the input of the next pass.
                    [readField, writeField] = [writeField, readField];
                }
                // After the final swap readField is the buffer written last.
                const latestField = readField;

                // 3. Vector field to SDF: latest vector field + source SDF -> slice z of destination
                gl.useProgram(this.vectorToSdfProgram);
                gl.framebufferTextureLayer(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, destSdf.data, 0, z);

                gl.uniform1i(this.vectorToSdfUniformLoc.vectorTexture, 0);
                gl.activeTexture(gl.TEXTURE0);
                gl.bindTexture(gl.TEXTURE_3D, latestField.data);

                gl.uniform1i(this.vectorToSdfUniformLoc.sdfTexture, 1);
                gl.activeTexture(gl.TEXTURE1);
                gl.bindTexture(gl.TEXTURE_3D, sourceSdf.data);

                gl.uniform1f(this.vectorToSdfUniformLoc.maxDistancePixels, maxDistancePixels);
                gl.uniformMatrix4fv(this.vectorToSdfUniformLoc.textureMatrix, false, textureMatrix);
                gl.bindVertexArray(vectorToSdfVao);
                gl.drawArraysInstanced(gl.TRIANGLES, 0, 6, 1);
            }
        }
        finally {
            // Leave the context in a predictable state and release the framebuffer.
            gl.bindFramebuffer(gl.FRAMEBUFFER, null);
            gl.bindVertexArray(null);
            gl.activeTexture(gl.TEXTURE0);
            gl.deleteFramebuffer(framebuffer);
        }

        roi.sdf = destSdf;
        roi.sdf.propagationPending = false;

        const endTime = Date.now();
        console.log("Propagation took " + ( endTime - startTime) + " milliseconds");
        vm.viewerState.notifyListeners();

    }
}