#!/usr/bin/env python
#
# histogram-matching.py
#
# generate a tone curve for RawTherapee by performing histogram matching
# between the embedded thumbnail and the rendering of RT using the neutral
# profile
#
# Copyright 2017 Alberto Griggio
#
# histogram-matching.py is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# The program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
import os, sys
import argparse
import subprocess
from PIL import Image
import tempfile
import shutil

# Names (or paths) of the external executables invoked by this script.
# getopts() replaces them when --dcraw / --rawtherapee are given.
DCRAW = 'dcraw'
RAWTHERAPEE = 'rawtherapee-cli'


def log(fmt, *args):
    """Write the %-formatted message ``fmt % args`` to stderr and flush it.

    No newline is appended; callers include it in *fmt* when they want one.
    """
    stream = sys.stderr
    message = fmt % args
    stream.write(message)
    stream.flush()
    

def get_source(rawimage):
    """Extract the thumbnail embedded in *rawimage* and return it as an image.

    Runs ``DCRAW -c -e rawimage`` with its standard output redirected into a
    temporary ``.jpg`` file, then opens that file with PIL.

    The pixel data is loaded before returning: ``Image.open`` is lazy, and
    the temporary file is deleted on the way out.  The temporary file is
    removed even when dcraw or PIL raises.

    A non-zero exit status of dcraw is only reported on stderr; if no usable
    image was produced, the exception raised by PIL propagates to the caller.
    """
    log('extracting embedded thumbnail from %s\n', rawimage)
    fd, name = tempfile.mkstemp('.jpg')
    try:
        # binary mode: dcraw writes image data, not text, to this handle
        with os.fdopen(fd, 'wb') as out:
            retcode = subprocess.call([DCRAW, '-c', '-e', rawimage],
                                      stdout=out)
        if retcode != 0:
            log('warning: %s exited with status %s\n', DCRAW, retcode)
        ret = Image.open(name)
        # force decoding now, while the file still exists
        ret.load()
        return ret
    finally:
        os.unlink(name)


# Processing profile that get_target() writes to a temporary .pp3 file and
# passes to rawtherapee-cli with -p.  It sets ToneCurve=false, selects the
# "fast" method for Bayer and X-Trans raw files, and enables resizing with
# Width and Height of 900.
target_profile = """
[Version]
AppVersion=5.4
Version=329

[Color Management]
InputProfile=(cameraICC)
ToneCurve=false
ApplyLookTable=false
ApplyBaselineExposureOffset=false
ApplyHueSatMap=true
DCPIlluminant=0
WorkingProfile=ProPhoto
OutputProfile=RT_sRGB

[RAW Bayer]
Method=fast

[RAW X-Trans]
Method=fast

[Resize]
Enabled=true
Width=900
Height=900
"""

def get_target(rawimage):
    """Render *rawimage* with RawTherapee using ``target_profile``.

    Writes ``target_profile`` to ``target.pp3`` inside a fresh temporary
    directory, runs ``RAWTHERAPEE`` on *rawimage* with ``target.jpg`` in the
    same directory as output, and returns that JPEG as a PIL image.  The
    output of the external program is captured and discarded.

    The pixel data is loaded before returning: ``Image.open`` is lazy, and
    the temporary directory is deleted on the way out.  The directory is
    removed even when RawTherapee or PIL raises.

    A non-zero exit status is only reported on stderr; if no usable image
    was produced, the exception raised by PIL propagates to the caller.
    """
    log('generating neutral jpg from %s\n', rawimage)
    dirname = tempfile.mkdtemp()
    try:
        name = os.path.join(dirname, 'target.jpg')
        pp3 = os.path.join(dirname, 'target.pp3')
        with open(pp3, 'w') as f:
            f.write(target_profile)
        p = subprocess.Popen([RAWTHERAPEE, '-q', '-f', '-p', pp3, '-o', name,
                              '-Y', '-c', rawimage],
                             stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        p.communicate()
        if p.returncode != 0:
            log('warning: %s exited with status %s\n',
                RAWTHERAPEE, p.returncode)
        ret = Image.open(name)
        # force decoding now, while the file still exists
        ret.load()
        return ret
    finally:
        shutil.rmtree(dirname)


def get_histogram(image):
    """Return the histogram of *image* after converting it to mode "L"."""
    return image.convert("L").histogram()


def get_cdf(histogram):
    """Return the running totals of *histogram* as a list.

    Entry ``i`` of the result is the sum of ``histogram[0..i]``.  The result
    is asserted to have exactly 256 entries.
    """
    cumulative = []
    total = 0
    for count in histogram:
        total = total + count
        cumulative += [total]
    assert len(cumulative) == 256
    return cumulative


def find_match(j, val, cdf):
    """Return the index of the entry of *cdf* that best matches *val*.

    *j* is a hint: when *val* lies between ``cdf[j]`` and ``cdf[j+1]`` the
    nearer of those two indices is returned without scanning (ties go to
    ``j``).  Otherwise *cdf* is scanned from the start; it is expected to be
    non-decreasing, as the output of get_cdf() is for non-negative counts.

    The scan returns the first index whose entry equals *val*.  If it first
    meets an entry greater than *val*, it returns whichever of that entry
    and its predecessor is nearer (ties go to the predecessor).  When *val*
    is below every entry the result is 0, and when it is above every entry
    the result is the last index.
    """
    # fast path: val is bracketed by the hinted entry and its successor
    if cdf[j] <= val and j+1 < len(cdf) and cdf[j+1] >= val:
        return j if val - cdf[j] <= cdf[j+1] - val else j+1
    diff = 0
    for i, c in enumerate(cdf):
        if c == val:
            return i
        elif c > val:
            if i == 0:
                # no predecessor to compare against; never return -1
                return 0
            return i if c - val < diff else i-1
        # distance from val to the entry just scanned (the next predecessor)
        diff = val - c
    return len(cdf) - 1
        

def histogram_matching(reference, target):
    """Return a 256-entry list mapping levels of *target* to *reference*.

    Both images are reduced to histograms with get_histogram() and then to
    cumulative sums with get_cdf().  Entry ``i`` of the result is the index
    that find_match() selects in the reference CDF for the target CDF value
    at level ``i``.
    """
    log('performing histogram matching\n')
    ref_hist, tgt_hist = get_histogram(reference), get_histogram(target)
    ref_cdf, tgt_cdf = get_cdf(ref_hist), get_cdf(tgt_hist)
    return [find_match(level, tgt_cdf[level], ref_cdf)
            for level in range(256)]


def hist2curve(m, npoints):
    """Yield ``(x, y)`` curve points sampled from the level mapping *m*.

    *m* is indexed by input level (0..255) and holds the matched output
    level.  Both coordinates of every yielded point are scaled by 1/255, and
    ``y`` is clamped to [0, 1].  *npoints* controls how coarsely the mapping
    is sampled below and above the pivot index computed first.

    Points are yielded in four groups: the first level whose ``y`` is
    positive, samples below the pivot, samples from the pivot upwards (both
    restricted to ``0.01 < y < 0.99``), and finally the highest level whose
    ``y`` is below 1.

    Raises AssertionError when no pivot strictly between 0 and 255 exists.
    """
    def level(i):
        # output level for input i, scaled to [0, 1] and clamped
        return max(0, min(float(m[i]/255.0), 1))

    # first search for the halfway point: the first index above 0 at which
    # the mapping reaches or crosses the identity
    idx = 0
    for i, v in enumerate(m):
        if i > 0 and v >= i:
            idx = i
            break
    assert 0 < idx < 255
    f = float(idx) / float(256-idx)
    p = max(int(npoints * f), 2)
    # floor division keeps the steps integral for range() on Python 3 as
    # well as Python 2
    step1 = max(idx // p, 1)
    # once the lower part has used up the whole point budget the upper part
    # is sampled with step 1; this also avoids dividing by zero when
    # p == npoints
    remaining = npoints - p
    step2 = max((255-idx) // remaining, 1) if remaining > 0 else 1
    for i in range(255):
        y = level(i)
        if y > 0:
            yield (float(i)/255.0, y)
            break
    for i in range(step1, idx-step1, step1):
        y = level(i)
        if 0.01 < y < 0.99:
            yield (float(i)/255.0, y)
    for i in range(idx, 256-step2, step2):
        y = level(i)
        if 0.01 < y < 0.99:
            yield (float(i)/255.0, y)
    for i in range(255, 0, -1):
        y = level(i)
        if y < 1:
            yield (float(i)/255.0, y)
            break


def make_pp3(mapping, npoints):
    """Return the text of a pp3 profile holding the curve for *mapping*.

    The profile consists of a [Version] section and an [Exposure] section
    whose ``Curve`` entry lists the points produced by
    ``hist2curve(mapping, npoints)`` as ``x;y`` pairs separated by ``;``.
    """
    points = ';'.join('%s;%s' % pt for pt in hist2curve(mapping, npoints))
    sections = [
        '[Version]\nAppVersion=5.4\nVersion=329\n',
        '[Exposure]\nAuto=false\nCurveMode=FilmLike',
        'Curve=1;%s;' % points,
        '',
    ]
    return '\n'.join(sections)


def getsize(reference, target, maxdim):
    """Return a ``(width, height)`` scaled so its longer side is *maxdim*.

    The ``(width, height)`` tuples of *reference* and *target* are compared
    as tuples and the greater one is used as the base size.  Each dimension
    is multiplied by the scale factor and truncated with ``int``.
    """
    ref_dims = (reference.width, reference.height)
    tgt_dims = (target.width, target.height)
    base = max(ref_dims, tgt_dims)
    f = float(maxdim) / float(max(base))
    width, height = base
    return int(width * f), int(height * f)


def getopts():
    """Parse the command line and return the resulting namespace.

    Accepts the positional ``rawfile`` plus ``-o/--output``, ``--dcraw``,
    ``--rawtherapee``, ``-n/--num-points`` (int, default 5) and
    ``-s/--size`` (int, default 900).  When ``--dcraw`` or ``--rawtherapee``
    is given, the corresponding module-level executable name is replaced.
    """
    global DCRAW, RAWTHERAPEE
    parser = argparse.ArgumentParser()
    parser.add_argument('rawfile')
    parser.add_argument('-o', '--output')
    for flag in ('--dcraw', '--rawtherapee'):
        parser.add_argument(flag)
    for short, name, default in (('-n', '--num-points', 5),
                                 ('-s', '--size', 900)):
        parser.add_argument(short, name, type=int, default=default)
    parsed = parser.parse_args()
    if parsed.dcraw:
        DCRAW = parsed.dcraw
    if parsed.rawtherapee:
        RAWTHERAPEE = parsed.rawtherapee
    return parsed


def main():
    """Generate the curve profile for the raw file named on the command line.

    Obtains the embedded thumbnail and the RawTherapee rendering, resizes
    both to the size chosen by getsize(), matches their histograms and
    writes the resulting pp3 text to the ``--output`` file, or to stdout
    when no output file was given.
    """
    opts = getopts()
    rawfile = opts.rawfile
    reference = get_source(rawfile)
    target = get_target(rawfile)
    size = getsize(reference, target, opts.size)
    mapping = histogram_matching(reference.resize(size), target.resize(size))
    pp3 = make_pp3(mapping, opts.num_points)
    if not opts.output:
        sys.stdout.write(pp3)
        return
    with open(opts.output, 'w') as out:
        out.write(pp3)


# Run main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
