#!/usr/bin/env python
# $URL: https://pypng.googlecode.com/svn/trunk/code/pipstack $
# $Rev: 159 $

# pipstack
# Combine input PNG files into a multi-channel output PNG.

"""
pipstack file1.png [file2.png ...]
pipstack can be used to combine 3 greyscale PNG files into a colour, RGB,
PNG file.  In fact it is slightly more general than that.  The number of
channels in the output PNG is equal to the sum of the numbers of
channels in the input images.  It is an error if this sum exceeds 4 (the
maximum number of channels in a PNG image is 4, for an RGBA image).  The
output colour model corresponds to the number of channels: 1 -
greyscale; 2 - greyscale+alpha; 3 - RGB; 4 - RGB+alpha.

In this way it is possible to combine 3 greyscale PNG files into an RGB
PNG (a common expected use) as well as more esoteric options: rgb.png +
grey.png = rgba.png; grey.png + grey.png = greyalpha.png.

Not yet handled: ancillary chunks in the inputs (gAMA, pHYs, iCCP, sRGB,
tIME, and any others) are not copied to the output PNG.
"""

class Error(Exception):
    """Raised by pipstack when its input cannot be stacked into one PNG."""

def stack(out, inp):
    """Stack the input PNG files into a single output PNG.

    `out` is the file-like object the output PNG is written to (it
    should accept binary data).  `inp` is a sequence of inputs, each
    passed to `png.Reader`.  The channels of the inputs are
    concatenated, in order, to form the channels of each output pixel.

    Raises `Error` if `inp` is empty or if the inputs have more than 4
    channels in total.  Raises `NotImplementedError` if the inputs do not
    all share one bitdepth, or do not all have the same size.
    """

    from array import array
    import itertools
    # Local module
    import png

    # itertools.izip only exists on Python 2; on Python 3 the builtin zip
    # is already lazy.  Either way rows are streamed, not read all at once.
    izip = getattr(itertools, 'izip', zip)

    if not inp:
        raise Error("Required input is missing.")

    readers = [png.Reader(i) for i in inp]
    # Let data be a list of (pixels, info) pairs: the last two items of
    # each asDirect() result, i.e. an iterable of rows and an info dict.
    data = [r.asDirect()[2:] for r in readers]
    rows = [pixels for pixels, _ in data]
    infos = [info for _, info in data]
    # Number of channels contributed by each input, in input order.
    planes = [info['planes'] for info in infos]
    totalchannels = sum(planes)

    if not (0 < totalchannels <= 4):
        raise Error("Too many channels in input.")
    # Output colour model follows the channel count:
    # 1 grey, 2 grey+alpha, 3 RGB, 4 RGB+alpha.
    alpha = totalchannels in (2, 4)
    greyscale = totalchannels in (1, 2)

    bitdepth = []
    for info in infos:
        b = info['bitdepth']
        try:
            if b == int(b):
                bitdepth.append(b)
                continue
        except (TypeError, ValueError):
            pass
        # Not a single integer: assume a tuple of per-channel bitdepths.
        bitdepth += b
    # Currently, fail unless all bitdepths equal.
    if len(set(bitdepth)) > 1:
        raise NotImplementedError(
            "Cannot cope when bitdepths differ - sorry!")
    bitdepth = bitdepth[0]
    # 'B' (unsigned byte) holds samples up to 8 bits, 'H' (unsigned
    # short) holds 16-bit samples.
    arraytype = 'BH'[bitdepth > 8]

    sizes = [info['size'] for info in infos]
    # Currently, fail unless all images same size.
    if len(set(sizes)) > 1:
        raise NotImplementedError("Cannot cope when sizes differ - sorry!")
    size = sizes[0]
    # Values per row
    vpr = totalchannels * size[0]

    def iterstack():
        # izip yields one tuple per output row, holding the corresponding
        # row from every input image.
        for irow in izip(*rows):
            row = array(arraytype, [0]) * vpr
            # output channel
            och = 0
            for arow, n in zip(irow, planes):
                # ensure incoming row is an array so it can be sliced
                arow = array(arraytype, arow)
                # Scatter each input channel j into output channel och,
                # interleaving with stride totalchannels.
                for j in range(n):
                    row[och::totalchannels] = arow[j::n]
                    och += 1
            yield row

    w = png.Writer(size[0], size[1],
      greyscale=greyscale, alpha=alpha, bitdepth=bitdepth)
    w.write(out, iterstack())


def main(argv=None):
    """Command-line entry point: stack the PNG files named in `argv`.

    `argv` defaults to `sys.argv`; its first element (the program name)
    is skipped.  The combined PNG is written to standard output.
    """
    import sys

    if argv is None:
        argv = sys.argv
    # PNG is binary data.  On Python 3 sys.stdout is a text stream and the
    # raw byte stream lives in sys.stdout.buffer; on Python 2 sys.stdout
    # has no such attribute and is used directly.
    out = getattr(sys.stdout, 'buffer', sys.stdout)
    # Slicing already produces a fresh list, so no extra copy is needed.
    return stack(out, argv[1:])


if __name__ == '__main__':
    # Run as a script: stack the PNG files named on the command line.
    main()
