/* Copyright (C) 2001, 2007 United States Government as represented by
   the Administrator of the National Aeronautics and Space Administration.
   All Rights Reserved.
 */
package gov.nasa.worldwind.servers.wms.utilities;

import java.io.*;
import java.awt.image.*;
import java.awt.*;

/**
 * @author brownrigg
 * @version $Id$
 */

public class WaveletCodec {

    /**
     * A suggested filename extension for wavelet-encodings.
     */
    public static final String WVT_EXT = ".wvt";

    /**
     * Loads a previously persisted wavelet encoding (as written by {@link #save(File)}) from the given file.
     *
     * @param file the file to read.
     * @return the decoded WaveletCodec, holding every coefficient of every band.
     * @throws IOException if the file cannot be read, is truncated, or has an invalid header.
     */
    public static WaveletCodec loadFully(File file) throws IOException {

        DataInputStream inp = null;
        try {
            inp = new DataInputStream(new BufferedInputStream(
                new FileInputStream(file)));

            WaveletCodec codec = new WaveletCodec();
            codec.resolutionX = inp.readInt();
            codec.resolutionY = inp.readInt();
            int imageTypeTag = inp.readInt();
            int numBands = inp.readInt();
            checkHeader(file, codec.resolutionX, codec.resolutionY, numBands);

            // Restore the encoding type so that a loaded codec can be saved again.
            codec.imageType = encodingTypeFor(imageTypeTag, numBands);
            codec.xform = new byte[numBands][codec.resolutionX * codec.resolutionY];
            for (int k = 0; k < numBands; k++)
                inp.readFully(codec.xform[k]);

            return codec;
        } finally {
            if (inp != null) inp.close();
        }
    }


    /**
     * Partially loads a previously persisted wavelet encoding from the given file, up to the given resolution.
     * Because coefficients are stored coarsest-first, only the first resolution*resolution bytes of each
     * band are read. The returned codec still reports the full encoded size from {@link #getResolutionX()}
     * and {@link #getResolutionY()}, but {@link #reconstruct(int)} accepts at most the loaded resolution.
     *
     * @param file the file to read.
     * @param resolution the power-of-two resolution to load; must not exceed the encoded width or height.
     * @return a WaveletCodec holding only the coefficients needed for the given resolution.
     * @throws IOException if the file cannot be read, is truncated, or has an invalid header.
     * @throws IllegalArgumentException if resolution is not a positive power of two, or exceeds the encoding.
     */
    public static WaveletCodec loadPartially(File file, int resolution) throws IOException, IllegalArgumentException {
        if (!WaveletCodec.isPowerOfTwo(resolution))
            throw new IllegalArgumentException("WaveletCodec.loadPartially(): input resolution not a power of two.");

        // NOTE: the try-finally clause was introduced because we had observed cases where, if an
        // exception was thrown during the read, the file would remain open, eventually leading
        // to the process exceeding its maximum open files.
        RandomAccessFile inp = null;
        try {
            inp = new RandomAccessFile(file, "r");

            WaveletCodec codec = new WaveletCodec();
            codec.resolutionX = inp.readInt();
            codec.resolutionY = inp.readInt();
            if (resolution > codec.resolutionX || resolution > codec.resolutionY)
                throw new IllegalArgumentException("WaveletCodec.loadPartially(): input resolution greater than encoded image");

            int imageTypeTag = inp.readInt();
            int numBands = inp.readInt();
            checkHeader(file, codec.resolutionX, codec.resolutionY, numBands);

            codec.imageType = encodingTypeFor(imageTypeTag, numBands);
            codec.xform = new byte[numBands][resolution * resolution];

            // Each band occupies resolutionX*resolutionY bytes after the four-int header; use long
            // arithmetic so the offset cannot overflow for large encodings.
            long headerBytes = 4L * (Integer.SIZE / Byte.SIZE);
            long bandBytes = (long) codec.resolutionX * codec.resolutionY;
            for (int k = 0; k < numBands; k++) {
                inp.seek(headerBytes + k * bandBytes);
                inp.readFully(codec.xform[k]);
            }

            return codec;

        } finally {
            if (inp != null) inp.close();
        }
    }

    /**
     * Creates a wavelet encoding from the given BufferedImage. The image must be square, with a side that is
     * a power of 2. The image must have exactly 1 band (treated as grayscale) or exactly 3 bands (treated as
     * RGB); alpha channels are not supported. The SampleModel component-type must be BYTE.
     *
     * @param image the image to encode; it is not modified.
     * @return the wavelet encoding of the image.
     * @throws IllegalArgumentException if image is null, not square, not power-of-two sized, or of an
     *         unsupported band count or data type.
     */
    public static WaveletCodec encode(BufferedImage image) throws IllegalArgumentException {

        if (image == null)
            throw new IllegalArgumentException("WaveletCodec.encode: null image");

        // Does image have the required resolution constraints?
        int xRes = image.getWidth();
        int yRes = image.getHeight();
        if (!isPowerOfTwo(xRes) || !isPowerOfTwo(yRes))
            throw new IllegalArgumentException("Image dimensions are not a power of 2");
        // The hierarchical coefficient layout (and reconstruct()) assume a square image; non-square input
        // would otherwise index past the end of the coefficient arrays.
        if (xRes != yRes)
            throw new IllegalArgumentException("Image dimensions are not square");

        // Try to determine image type...
        SampleModel sampleModel = image.getSampleModel();
        int numBands = sampleModel.getNumBands();
        if ( !(numBands == 1 || numBands == 3) || sampleModel.getDataType() != DataBuffer.TYPE_BYTE)
            throw new IllegalArgumentException("Image is not of BYTE type, or not recognized as grayscale or RGB (alpha-channel is not supported)");

        // Copy the pixels into float working storage; the transform is done in place and we must not
        // corrupt the BufferedImage's own data.
        float[][] imageData = readBands(image.getRaster(), xRes, yRes, numBands);

        forwardHaar(imageData, xRes);

        WaveletCodec codec = new WaveletCodec();
        codec.resolutionX = xRes;
        codec.resolutionY = yRes;
        codec.imageType = defaultTypeFor(numBands);
        codec.xform = quantizeHierarchically(imageData, xRes);
        return codec;
    }

    /**
     * Copies every sample of the raster into per-band, row-major float arrays of width xRes.
     */
    private static float[][] readBands(Raster rast, int xRes, int yRes, int numBands) {
        float[] dataElems = new float[numBands];
        float[][] bands = new float[numBands][xRes * yRes];
        int next = 0;
        for (int j = 0; j < yRes; j++) {
            for (int i = 0; i < xRes; i++, next++) {
                rast.getPixel(i, j, dataElems);
                for (int k = 0; k < numBands; k++)
                    bands[k][next] = dataElems[k];
            }
        }
        return bands;
    }

    /**
     * Applies a multi-level 2D Haar transform, in place, to each band of a square row-major image of side
     * res. At each level the top-left xformRes x xformRes region is replaced by its averages (top-left
     * quadrant) and its horizontal, vertical and diagonal details (the other three quadrants).
     */
    private static void forwardHaar(float[][] imageData, int res) {
        int numBands = imageData.length;
        float[][] workspace = new float[numBands][res * res];

        // A level of side 1 has nothing left to transform, so stop there. (Running it would copy stale
        // workspace contents over the data, which for a 1x1 image wipes out the only pixel.)
        for (int xformRes = res; xformRes > 1; xformRes /= 2) {
            int half = xformRes / 2;

            // Transform along the rows. The row stride is always the full image width, not the current level.
            for (int j = 0; j < xformRes; j++) {
                int offset = j * res;
                for (int i = 0; i < half; i++) {
                    int indx1 = offset + 2 * i;
                    int indx2 = indx1 + 1;
                    for (int k = 0; k < numBands; k++) {
                        float average = (imageData[k][indx1] + imageData[k][indx2]) / 2f;
                        float detail = imageData[k][indx1] - average;
                        workspace[k][offset + i] = average;
                        workspace[k][offset + i + half] = detail;
                    }
                }
            }
            // Outside the current region the workspace already equals imageData (from the previous copy,
            // or because the first level covers the whole image), so a full copy is safe.
            copyBands(workspace, imageData);

            // Now transform along the columns.
            for (int j = 0; j < xformRes; j++) {
                for (int i = 0; i < half; i++) {
                    int indx1 = j + (2 * i) * res;
                    int indx2 = indx1 + res;
                    for (int k = 0; k < numBands; k++) {
                        float average = (imageData[k][indx1] + imageData[k][indx2]) / 2f;
                        float detail = imageData[k][indx1] - average;
                        workspace[k][j + i * res] = average;
                        workspace[k][j + (i + half) * res] = detail;
                    }
                }
            }
            copyBands(workspace, imageData);
        }
    }

    /**
     * Rearranges transformed coefficients into a coarse-to-fine layout and quantizes them to bytes. The
     * layout is: overall average, then for each scale the H, V and D sub-blocks in that order.
     */
    private static byte[][] quantizeHierarchically(float[][] imageData, int res) {
        int numBands = imageData.length;
        byte[][] xform = new byte[numBands][res * res];

        // NOTE: the first byte of each channel is different; it represents the average color of the
        // overall image, and as such should be an unsigned quantity in the range 0..255.
        // All other values are signed coefficients, so the clamping boundaries are different.
        for (int k = 0; k < numBands; k++)
            xform[k][0] = (byte) Math.min(255, Math.max(0, Math.round(imageData[k][0])));

        int next = 1;
        for (int scale = 1; scale < res; scale *= 2) {     // scale is the inverse of the magnification level
            for (int subBlock = 0; subBlock < 3; subBlock++) {
                // subBlock 0: top-right (H), 1: bottom-left (V), 2: bottom-right (D)
                int colOffset = ((subBlock % 2) == 0) ? scale : 0;
                int rowOffset = (subBlock > 0) ? scale * res : 0;
                for (int j = 0; j < scale; j++) {
                    for (int i = 0; i < scale; i++, next++) {
                        int indx = rowOffset + colOffset + j * res + i;
                        for (int k = 0; k < numBands; k++) {
                            xform[k][next] = (byte) Math.max(Byte.MIN_VALUE,
                                Math.min(Byte.MAX_VALUE, Math.round(imageData[k][indx])));
                        }
                    }
                }
            }
        }
        return xform;
    }

    /**
     * Copies each band of src over the corresponding band of dst.
     */
    private static void copyBands(float[][] src, float[][] dst) {
        for (int k = 0; k < src.length; k++)
            System.arraycopy(src[k], 0, dst[k], 0, src[k].length);
    }


    /**
     * Reconstructs an image from this wavelet encoding at the given resolution. The specified resolution
     * must be a power of two, and must be less than or equal to the resolution of the encoding (and, for a
     * partially loaded encoding, to the resolution that was loaded).
     *
     * This reconstruction algorithm was hinted at in:
     *
     *    "Principles of Digital Image Synthesis"
     *    Andrew Glassner
     *    1995, pp. 296
     *
     * @param resolution side length, in pixels, of the square image to produce.
     * @return reconstructed image; TYPE_BYTE_GRAY for one band, otherwise TYPE_INT_RGB.
     * @throws IllegalArgumentException if resolution is not a positive power of two, or exceeds the
     *         encoded width/height or the number of coefficients available.
     */
    public BufferedImage reconstruct(int resolution) throws IllegalArgumentException {

        if (!isPowerOfTwo(resolution))
            throw new IllegalArgumentException("WaveletCodec.reconstruct(): resolution not a positive power of two: " + resolution);

        int numBands = this.xform.length;
        if (resolution > this.resolutionX || resolution > this.resolutionY
                || (long) resolution * resolution > this.xform[0].length)
            throw new IllegalArgumentException("WaveletCodec.reconstruct(): resolution exceeds that of the encoding: " + resolution);

        // Working storage only needs to be as large as the requested output, not the full encoding.
        int numPixels = resolution * resolution;
        int[][] imageData = new int[numBands][numPixels];

        // Snapshot of the current approximation; at most (resolution/2)^2 values are needed.
        // (Named after Glassner's convention.)
        int half = resolution / 2;
        int[][] A = new int[numBands][half * half];

        // Prime the process. Recall that the first byte of each channel is a color value, not
        // signed coefficients. So treat it as an unsigned value.
        for (int k = 0; k < numBands; k++)
            imageData[k][0] = 0x000000ff & this.xform[k][0];

        int offset = 1;
        for (int scale = 1; scale < resolution; scale *= 2) {
            int numVals = scale * scale;

            // Copy out the current approximation, since it is overwritten in place below.
            int next = 0;
            for (int j = 0; j < scale; j++) {
                for (int i = 0; i < scale; i++, next++) {
                    for (int k = 0; k < numBands; k++)
                        A[k][next] = imageData[k][j * resolution + i];
                }
            }

            // H, V and D detail coefficients for this scale are stored consecutively.
            int hOffset = offset;
            int vOffset = offset + numVals;
            int dOffset = offset + 2 * numVals;

            next = 0;
            for (int j = 0; j < scale; j++) {
                for (int i = 0; i < scale; i++, next++) {
                    int topLeft = 2 * j * resolution + 2 * i;
                    for (int k = 0; k < numBands; k++) {
                        byte[] band = this.xform[k];
                        int a = A[k][next];
                        int h = band[hOffset + next];   // signed coefficients: sign-extension intended
                        int v = band[vOffset + next];
                        int d = band[dOffset + next];
                        imageData[k][topLeft] = a + h + v + d;
                        imageData[k][topLeft + 1] = a - h + v - d;
                        imageData[k][topLeft + resolution] = a + h - v - d;
                        imageData[k][topLeft + resolution + 1] = a - h - v + d;
                    }
                }
            }

            offset += 3 * numVals;
        }

        // Copy to bytes and clamp to byte-range...
        byte[][] imageBytes = new byte[numBands][numPixels];
        for (int k = 0; k < numBands; k++) {
            for (int p = 0; p < numPixels; p++)
                imageBytes[k][p] = (byte) Math.max(0, Math.min(255, imageData[k][p]));
        }

        // Finally, construct a BufferedImage...
        BandedSampleModel sm = new BandedSampleModel(DataBuffer.TYPE_BYTE, resolution, resolution, numBands);
        DataBufferByte dataBuff = new DataBufferByte(imageBytes, numPixels);
        WritableRaster rast = Raster.createWritableRaster(sm, dataBuff, new Point(0, 0));
        int imageType = (numBands == 1) ? BufferedImage.TYPE_BYTE_GRAY : BufferedImage.TYPE_INT_RGB;
        BufferedImage image = new BufferedImage(resolution, resolution, imageType);
        image.getRaster().setRect(rast);

        return image;
    }

    /**
     * Saves this wavelet encoding to the given File, in the format read by {@link #loadFully(File)} and
     * {@link #loadPartially(File, int)}.
     *
     * @param file the destination file; it is created or overwritten.
     * @throws IOException if the file cannot be written.
     * @throws IllegalStateException if this codec was only partially loaded, since its coefficients no
     *         longer match the resolution recorded in the header.
     */
    public void save(File file) throws IOException {
        int bandSize = this.resolutionX * this.resolutionY;
        for (int k = 0; k < this.xform.length; k++) {
            if (this.xform[k].length != bandSize)
                throw new IllegalStateException("WaveletCodec.save(): cannot save a partially loaded encoding");
        }

        EncodingType type = (this.imageType != null) ? this.imageType : defaultTypeFor(this.xform.length);

        DataOutputStream out = null;
        try {
            out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)));
            out.writeInt(this.resolutionX);
            out.writeInt(this.resolutionY);
            out.writeInt(type.getTag());
            out.writeInt(this.xform.length);
            for (int k = 0; k < this.xform.length; k++)
                out.write(this.xform[k]);
        } finally {
            // close() also flushes the buffer; let any failure there propagate.
            if (out != null) out.close();
        }
    }

    /**
     * Returns the width, in pixels, of the encoded image.
     *
     * @return encoded width
     */
    public int getResolutionX() { return this.resolutionX; }

    /**
     * Returns the height, in pixels, of the encoded image.
     *
     * @return encoded height
     */
    public int getResolutionY() { return this.resolutionY; }

    /**
     * Convenience method for testing whether a value is a power of two.
     *
     * @param value the value to test.
     * @return true if value is a positive power of two (1, 2, 4, ...); false otherwise, including for 0
     *         and negative values.
     */
    public static boolean isPowerOfTwo(int value) {
        // Exact integer test; a floating-point log can misclassify large powers of two.
        return value > 0 && (value & (value - 1)) == 0;
    }

    /**
     * Returns a resolution value that is the nearest power of 2 greater than or equal to the given
     * value. Values less than or equal to 1 yield 1. Values above 2^30 have no int power of two
     * that is greater than or equal to them, so Integer.MAX_VALUE is returned.
     *
     * @param resolution the value to round up.
     * @return power of two resolution
     */
    public static int nearestPowerOfTwo(int resolution) {
        if (resolution <= 1)
            return 1;
        int highest = Integer.highestOneBit(resolution);
        if (highest == resolution)
            return resolution;
        return (highest == (1 << 30)) ? Integer.MAX_VALUE : highest << 1;
    }

    /**
     * Convenience method to compute the log-2 of a value.
     *
     * @param value the value whose base-2 logarithm is wanted.
     * @return log2(value), computed as ln(value) / ln(2).
     */
    public static double logBase2(double value) {
        return Math.log(value) / Math.log(2.);
    }

    /**
     * Returns the encoding type encode() assigns for the given band count.
     */
    private static EncodingType defaultTypeFor(int numBands) {
        return (numBands == 1) ? EncodingType.GRAY_SCALE : EncodingType.COLOR_RGB;
    }

    /**
     * Maps a persisted type tag back to its EncodingType. An unrecognized tag falls back to the band-count rule.
     */
    private static EncodingType encodingTypeFor(int tag, int numBands) {
        for (EncodingType type : EncodingType.values()) {
            if (type.getTag() == tag)
                return type;
        }
        return defaultTypeFor(numBands);
    }

    /**
     * Rejects header values that cannot describe a valid encoding, such as non-positive sizes or band
     * counts, or a band size that overflows an int.
     */
    private static void checkHeader(File file, int resX, int resY, int numBands) throws IOException {
        if (resX <= 0 || resY <= 0 || numBands <= 0 || (long) resX * resY > Integer.MAX_VALUE)
            throw new IOException("WaveletCodec: invalid header in " + file + " (resolution " + resX + "x"
                + resY + ", " + numBands + " bands)");
    }

    public enum EncodingType {
        GRAY_SCALE(0x67726179),  // ascii "gray"
        COLOR_RGB(0x72676220);   // ascii "rgb "

        private EncodingType(int tag) { this.tag = tag; }
        public int getTag() { return this.tag; }
        private final int tag;
    }

    private int resolutionX;
    private int resolutionY;
    private EncodingType imageType;
    private byte[][] xform;

}
