#!/usr/bin/python

import sys

import oam

# WKT definition of the WGS 84 geographic coordinate system (EPSG:4326).
# write_vrt() emits this string verbatim as the <SRS> of every VRT it writes.
epsg4326_wkt = 'GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]]'

def parse_options():
    """
    Parse the command line and return the resulting options object.

    On top of whatever options oam.option_parser() already defines, this
    registers -a/--archive, -r/--resolution and -s/--size.  The first
    positional argument is stored as ``opts.output``; the remaining
    positional arguments are handed to oam.parse_bbox() and the result is
    stored as ``opts.bbox``.

    If no positional arguments are given, the usage text is printed and the
    process exits with status 0.
    """
    parser = oam.option_parser()
    #parser.add_option("-l", "--layer", dest="layer", type="int", help="Layer ID")
    parser.add_option("-a", "--archive", dest="archive", action="store_true", default=False, help="Include archive images")
    parser.add_option("-r", "--resolution", dest="resolution", default="min", help="Resolution strategy (min|max|avg)")
    parser.add_option("-s", "--size", dest="size", help="Output VRT size in pixels (width,height)")
    (opts, args) = parser.parse_args()
    if not args:
        # print_help() writes the usage text itself and returns None, so it
        # must not be wrapped in a print (that would emit a stray "None").
        parser.print_help()
        sys.exit(0)
    # First positional argument is the output path; the rest describe the bbox.
    opts.output = args.pop(0)
    opts.bbox = oam.parse_bbox(args)
    return opts

def determine_image_size(strategy, images, bbox):
    """
    Compute the (width, height) in pixels of an output covering ``bbox``.

    A single target resolution is derived from the ``px_width`` attribute of
    every image: the smallest value for strategy "min", the largest for
    "max", and the arithmetic mean for any other strategy value.  The same
    resolution is applied to both axes.

    ``bbox`` is indexed as (bbox[0], bbox[1], bbox[2], bbox[3]); the width
    comes from bbox[2] - bbox[0] and the height from bbox[3] - bbox[1].
    Each dimension is rounded to the nearest integer by adding 0.5 and
    truncating.
    """
    pixel_widths = [image.px_width for image in images]

    if strategy == "min":
        resolution = min(pixel_widths)
    elif strategy == "max":
        resolution = max(pixel_widths)
    else:
        # Any other strategy name is treated as "avg".
        resolution = sum(pixel_widths) / len(pixel_widths)

    def to_pixels(extent):
        # Round-half-up conversion from a map extent to a pixel count.
        return int(extent / resolution + 0.5)

    return to_pixels(bbox[2] - bbox[0]), to_pixels(bbox[3] - bbox[1])

def write_vrt_source(output, target, source, band):
    """
    Write one <ComplexSource> element describing ``source`` to ``output``.

    output -- writable file-like object receiving the XML text.
    target -- the mosaic being built; its ``bbox``, ``path`` and
              ``window()`` are used.
    source -- the contributing image; its ``intersection()``, ``window()``,
              ``bands``, ``path``, ``width`` and ``height`` are used.
    band   -- key into ``source.bands``; the looked-up value is unpacked as
              (band index, data type, block width, block height).

    If the source does not overlap the target, or either pixel window comes
    back empty, an error line is printed and nothing is written.  The source
    is referenced through GDAL's /vsicurl/ prefix and pixel value 0 is
    declared as NODATA.
    """
    # Local import keeps this block self-contained.
    from xml.sax.saxutils import escape

    overlap = source.intersection(target.bbox)
    if not overlap:
        print("ERROR: No overlap between %s and %s" % (target.path, source.path))
        return
    source_win = source.window(overlap)
    target_win = target.window(overlap)
    if not source_win or not target_win:
        print("ERROR: Source window: %s Target window: %s" % (source_win, target_win))
        return
    band_idx, data_type, block_width, block_height = source.bands[band]
    # The path ends up inside XML character data, so '&', '<' and '>' (common
    # in URLs with query strings) must be escaped or the VRT is malformed.
    source_path = escape(source.path)
    output.write('\t\t<ComplexSource>\n')
    output.write(('\t\t\t<SourceFilename relativeToVRT="0">/vsicurl/%s' +
        '</SourceFilename>\n') % source_path)
    output.write('\t\t\t<SourceBand>%i</SourceBand>\n' % band_idx)
    ### TODO: make a config option for this.
    output.write('\t\t\t<NODATA>0</NODATA>\n')
    output.write(('\t\t\t<SourceProperties RasterXSize="%i" RasterYSize="%i"' +
                ' DataType="%s" BlockXSize="%i" BlockYSize="%i"/>\n')
                % (source.width, source.height, data_type, block_width, block_height))
    # Both windows are formatted as (xOff, yOff, xSize, ySize).
    output.write('\t\t\t<SrcRect xOff="%i" yOff="%i" xSize="%i" ySize="%i"/>\n' % source_win)
    output.write('\t\t\t<DstRect xOff="%i" yOff="%i" xSize="%i" ySize="%i"/>\n' % target_win)
    output.write('\t\t</ComplexSource>\n')

def write_vrt(target, sources):
    """
    Write a three-band (Red, Green, Blue) VRT mosaic to ``target.path``.

    target  -- describes the output; its ``path``, ``width``, ``height`` and
               ``transform`` are used (``transform`` must supply the six
               GeoTransform coefficients).
    sources -- iterable of images; each is offered to write_vrt_source()
               once per band and must be iterable more than once.

    The SRS written is always the module-level EPSG:4326 WKT, not the
    target's own CRS.
    """
    # NOTE: the following sanity checks were sketched here originally but are
    # NOT enforced: every source sharing the target's pixel scale, no source
    # being rotated, and every source sharing one projection.
    # 'with' guarantees the file is flushed and closed even if a source fails.
    with open(target.path, 'w') as output:
        output.write('<VRTDataset rasterXSize="%i" rasterYSize="%i">\n'
            % (target.width, target.height))
        output.write('\t<GeoTransform>%24.16f, %24.16f, %24.16f, %24.16f, %24.16f, %24.16f</GeoTransform>\n'
            % target.transform)

        # output.write('\t<SRS>%s</SRS>\n' % target.crs)
        output.write('\t<SRS>%s</SRS>\n' % epsg4326_wkt)

        for band, interp in enumerate(('Red', 'Green', 'Blue')):
            # VRT band numbers are 1-based.
            output.write('\t<VRTRasterBand dataType="Byte" band="%i">\n' % (band+1))
            output.write('\t\t<ColorInterp>%s</ColorInterp>\n' % interp)
            #output.write('\t\t<NoDataValue>0</NoDataValue>\n')
            for source in sources:
                # The colour name doubles as the key into source.bands.
                write_vrt_source(output, target, source, interp)
            output.write('\t</VRTRasterBand>\n')

        output.write('</VRTDataset>\n')

def main(opts):
    """
    Query images for ``opts.bbox`` and write a VRT mosaic to ``opts.output``.

    Archive images are requested only when ``opts.archive`` is set.  The
    output size comes from ``opts.size`` ("width,height") when given,
    otherwise it is derived from the images via ``opts.resolution``.

    Raises Exception when the query returns no images or when the images do
    not all share one CRS.
    """
    client = oam.build_client(opts)
    query = {"archive": "true"} if opts.archive else {}
    images = client.images_by_bbox(opts.bbox, **query)
    if not images:
        raise Exception("No images found for that bounding box!")

    # All images must agree on a single SRS.
    distinct_crs = set(image.crs for image in images)
    if len(distinct_crs) > 1:
        raise Exception("Not all images have the same CRS!")

    # An explicit size wins; otherwise fall back to the resolution strategy.
    if opts.size:
        width, height = [int(part) for part in opts.size.split(",")]
    else:
        width, height = determine_image_size(opts.resolution, images, opts.bbox)

    target = oam.Image(
        opts.output,
        opts.bbox,
        width=width,
        height=height,
        crs=images[0].crs,
    )
    # TODO: do something intelligent about masking/alpha band etc?
    write_vrt(target, images)

# Script entry point: parse the command line, then build the VRT.
# Does nothing when the module is imported.
if __name__ == "__main__":
    main(parse_options())
