#!/usr/bin/python
#emacs: -*- mode: python-mode; py-indent-offset: 4; tab-width: 4; indent-tabs-mode: nil -*-
#ex: set sts=4 ts=4 sw=4 et:
#------------------------- =+- Python script -+= -------------------------
"""Simple script to use produce lateralized Harvard-Oxford cortical atlas

  Yaroslav Halchenko                                            Dartmouth
  web:     http://www.onerussian.com                              College
  e-mail:  yoh@onerussian.com                              ICQ#: 60653192

 COPYRIGHT: Yaroslav Halchenko 2012-2013

 LICENSE: MIT

  Permission is hereby granted, free of charge, to any person obtaining a copy
  of this software and associated documentation files (the "Software"), to deal
  in the Software without restriction, including without limitation the rights
  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  copies of the Software, and to permit persons to whom the Software is
  furnished to do so, subject to the following conditions:

  The above copyright notice and this permission notice shall be included in
  all copies or substantial portions of the Software.

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  THE SOFTWARE.
"""
#-----------------\____________________________________/------------------

__author__ = 'Yaroslav Halchenko'
__copyright__ = 'Copyright (c) 2012 Yaroslav Halchenko'
__license__ = 'MIT'

import os, re
import numpy as np
import nibabel as nib                   # To load/save nifti volumes
from lxml import etree                  # To load/save .xml atlases

verbosity = 3   # global chattiness level: messages with level <= verbosity are printed


def verbose(level, s, fmt=None):
    """Print ``s % fmt`` (indented by `level`) if `level` <= `verbosity`.

    `fmt` defaults to an empty mapping instead of the previous mutable
    default argument ``{}``.  The message is assembled into a single
    ``print(...)`` call that reproduces the old two-item py2 ``print``
    output exactly (indent, separating space, newline) while also being
    valid Python 3 syntax.
    """
    if fmt is None:
        fmt = {}
    if level <= verbosity:
        print(" " * max(0, level - 1) + " " + (s % fmt))

def load_atlas(path):
    """Load an FSL-style XML atlas description and its NIfTI volumes.

    Parameters
    ----------
    path : str
      path to the atlas .xml file (volumes are resolved relative to its
      directory)

    Returns
    -------
    (maps, labels, nifti_hdrs) where
      maps : dict of resolution (mm, int) -> 4D probability array
      labels : list of str region names
      nifti_hdrs : dict of resolution -> NIfTI header of the loaded volume
    """
    atlas_dir = os.path.dirname(path)

    # pull the pieces of interest out of the XML description
    tree = etree.parse(path)
    image_nodes = tree.findall('header/images/imagefile')
    label_nodes = tree.findall('data/label')

    # imagefile entries are absolute-looking ('/HarvardOxford/...');
    # strip the leading slash and anchor them at the atlas directory
    volume_paths = [os.path.join(atlas_dir, node.text.lstrip('/'))
                    for node in image_nodes]
    labels = [node.text for node in label_nodes]

    maps = {}
    nifti_hdrs = {}
    for volume_path in volume_paths:
        # resolution in mm is encoded in the filename suffix, e.g. '-2mm'
        resolution = int(re.sub('.*-([0-9]+)mm', r'\1', volume_path))
        img = nib.load(volume_path + '.nii.gz')
        maps[resolution] = img.get_data()
        nifti_hdrs[resolution] = img.get_header()
    return maps, labels, nifti_hdrs

def get_lateralities(maps, labels, res):
    """Derive per-voxel Left/Right probability (0..1) maps from a
    sub-cortical atlas.

    Parameters
    ----------
    maps : 4D array
      probability maps in %, last axis matching `labels`
    labels : list of str
      region names; 'Left '/'Right ' prefixes select the hemisphere
    res : int
      resolution in mm -- currently unused, kept for interface
      compatibility with callers

    Returns
    -------
    dict with keys 'Left' and 'Right' of summed probability maps scaled
    to the 0..1 range (clipped at 1.0).
    """
    result = {}
    for side in ('Left', 'Right'):
        # indices of all regions belonging to this hemisphere
        side_indices = [idx for idx, name in enumerate(labels)
                        if name.startswith(side)]
        verbose(2, "%d labels for %s", (len(side_indices), side))
        # total probability of being in this hemisphere, per voxel
        total = maps[..., side_indices].sum(axis=-1)
        assert(np.max(total) <= 102)
        # clip -- rounding in the source atlas can push sums slightly
        # above 100%
        total[total > 100] = 100
        result[side] = total / 100.     # convert % -> ratio

    return result

def normalize_lateralities(lateralities):
    """Scale the two laterality maps in place so they sum to 1.0
    wherever their combined probability is non-zero.

    Kept separate from get_lateralities so that the raw (unscaled)
    lateralities could still be inspected/verified if desired.

    Parameters
    ----------
    lateralities : dict
      exactly two entries ('Left'/'Right') of equal-shape float arrays;
      modified in place
    """
    left, right = lateralities.values()
    combined = left + right             # total ratio per voxel
    nonzero = np.nonzero(combined)
    for side_map in (left, right):
        assert(np.all(side_map >= 0))
        side_map[nonzero] /= combined[nonzero]
    # sanity check: after normalization every voxel is ~100% or 0%
    rounded = np.unique(np.round((left + right) * 100)).astype(int)
    assert(set(rounded) == set([0, 100]))

def lateralize_atlas(maps, labels, divisions):
    """Combine original atlas and laterality information
    while correspondingly growing number of labels etc

    Parameters
    ----------
    maps : 4D array
      probability maps, last axis matching `labels`
    labels : list of str
      region names of the original atlas
    divisions : dict
      division name (e.g. 'Left'/'Right') -> 3D weight map in 0..1

    Returns
    -------
    (out_maps, out_labels): each original volume multiplied by every
    division map (divisions iterated in sorted key order), with labels
    prefixed by the division name.
    """
    shape = maps.shape
    assert(shape[-1] == len(labels))
    out_labels = []
    out_maps = np.zeros(shape[:3] + (shape[3] * len(divisions),),
                        dtype=maps.dtype)
    out_idx = 0
    # walk label axis first so each volume is weighted by every division
    for label, volume in zip(labels, np.rollaxis(maps, 3, 0)):
        verbose(3, "area %s ..." % label)
        for division in sorted(divisions.keys()):
            out_maps[..., out_idx] = volume[...] * divisions[division]
            out_labels.append('%s %s' % (division, label))
            out_idx += 1
    assert(out_idx == out_maps.shape[-1])
    return out_maps, out_labels

def get_header(hdr, data, colormap=None):
    """Return a copy of `hdr` with display fields set for `data`.

    Parameters
    ----------
    hdr : mapping supporting .copy() and item assignment (NIfTI header)
    data : array
      volume the header is for; its max becomes cal_max
    colormap : str or None
      if given (and non-empty), stored into the header's aux_file field

    Returns
    -------
    a copy of `hdr` with cal_min=0, cal_max=max(data) and optionally
    aux_file set.
    """
    fields = {'cal_min': 0,
              'cal_max': np.max(data)}
    if colormap:
        fields['aux_file'] = colormap
    nifti_hdr = hdr.copy()
    # .items() instead of Python2-only .iteritems() -- identical
    # behavior on py2, and keeps the function usable on py3
    for field, value in fields.items():
        nifti_hdr[field] = value
    return nifti_hdr

def make_atlas(maps,                    #  dict of resolution: map
               labels,
               fullname,
               name=None,
               shortname=None,
               imagefilename=None,
               topdir='.',
               subdir='',             # where to dump volumes
               type_="Probabilistic",
               nifti_hdrs=None,
               images_thr=[],
               colormap_thr='MGH-Subcortical'):
    """Write out an FSL-style atlas: XML description plus NIfTI volumes.

    For every resolution in `maps` this stores the 4D probability volume
    and, for every threshold in `images_thr`, a 3D "maxprob" summary
    volume (argmax label index + 1 where the winning probability passes
    the threshold, 0 elsewhere).  Finally an XML file (`name`.xml under
    `topdir`) is written that lists the volumes and one representative
    (x, y, z) voxel coordinate per label.

    Parameters: `fullname`/`shortname`/`name` go into the XML header and
    file names; `imagefilename` is the volume basename (defaults to
    `name`); `nifti_hdrs` maps resolution -> header to reuse;
    `colormap_thr` is stored as aux_file of the maxprob volumes.

    Returns the path of the written XML file.

    NOTE: the file-name templates below are expanded with ``% locals()``,
    so renaming local variables in this function can silently break the
    generated paths.
    """
    # first generate an xml file
    resolutions = sorted(maps.keys())[::-1] # my guess that we better
                                            # stay more consistent and
                                            # first list low-res one
                                            # since that is the one
                                            # which would be used for
                                            # the coordinates
    SE = etree.SubElement
    root = etree.Element('atlas', attrib=dict(version='1.0'))
    header = SE(root, 'header')
    SE(header, 'name').text = fullname
    SE(header, 'shortname').text = shortname
    SE(header, 'type').text = type_
    if subdir is not None:
        # make sure the output sub-directory exists and turn subdir into
        # the '/<subdir>/' form used inside the XML imagefile entries
        subdir_full = os.path.join(topdir, subdir)
        if not os.path.exists(subdir_full):
            os.makedirs(subdir_full)
        subdir = '/%s/' % subdir

    imagefilename = imagefilename or name

    typ = {"Probabilistic": "prob"}[type_] # abbreviated version
    for res in resolutions:
        # each resolution gets its own <images> element in the header
        images = SE(header, "images")
        prob_file = '%(subdir)s%(imagefilename)s-%(typ)s-%(res)smm' % locals()
        SE(images, 'imagefile').text = prob_file

        # store the actual probability map
        map_ni = nib.Nifti1Image(
            maps[res], None,
            get_header(nifti_hdrs[res], maps[res]))
        map_ni.to_filename(os.path.join(topdir, prob_file.lstrip('/')+'.nii.gz'))

        # Generate maxprob entries/files
        map_argmax = np.argmax(maps[res], axis=-1)
        map_max = np.max(maps[res], axis=-1)
        for thr in images_thr:
            # disabled reference implementation kept for documentation:
            # it shows the straightforward (but memory-hungry) way to
            # compute the same thresholded argmax map
            if False:
                # Compute/dump such a map
                map_thr_ = maps[res].copy()
                # this blows our memory limits on head1 on 1mm atlas
                # Theoretically we could avoid copy, get indexes of maxes
                # ones, and then assign that max prob in that single volume, for further
                # thresholding.  Otherwise now consumes 10G in peak
                map_thr_[map_thr_ < thr ] = 0
                map_thr = np.argmax(map_thr_, axis=-1)
                # Since 0 is not different from 0 -- offset by 1
                # but only where we had some in any volume
                map_thr[np.sum(map_thr_, axis=-1)>0] += 1
            # now we do it more efficiently without bloating memory
            map_thr = np.zeros(map_argmax.shape, dtype=int)
            # tiny number solely to maintain >= for non-0 thr, and have > for 0 thr
            over = np.where(map_max >= (thr if thr else 0.000001))
            # assign index for the maximum (+1) where that max passed
            # the threshold
            map_thr[over] = map_argmax[over] + 1
            """original subcortical atlas is even fancier -- there is a magical
            /usr/share/data/harvard-oxford-atlases/HarvardOxford/labels providing
            relabeling"""
            assert(map_thr[0, 0, 0] == 0) # should not have any label outside of the brain
            # store maxprob volume
            thresh_file = '%(subdir)s%(imagefilename)s-max%(typ)s-thr%(thr)d-%(res)smm' \
                          % locals()
            SE(images, 'summaryimagefile').text = thresh_file

            map_ni = nib.Nifti1Image(
                map_thr, None,
                get_header(nifti_hdrs[res], map_thr, colormap_thr))
            map_ni.to_filename(os.path.join(topdir, thresh_file.lstrip('/')+'.nii.gz'))


    data = SE(root, 'data')
    # label coordinates are computed in the first (highest-mm) resolution
    map_ = maps[resolutions[0]]

    # now go through all the labels
    for index, l in enumerate(labels):
        # figure out representative x,y,z
        # nz = np.asanyarray(np.nonzero(map_[..., index]))
        # let's take only the 'maximum' voxels
        max_prob = np.max(map_[..., index])
        nz = np.array(np.where(map_[..., index] >= max_prob))
        # figure out center and closest point within the ROI actually
        center = np.mean(nz, axis=1)
        dnz = nz - center[:, None]
        distances = np.sum(dnz * dnz, axis=0)
        x, y, z = nz[:, np.argmin(distances)]
        l_ = SE(data, 'label')
        # to place all the attributes in defined order we can't just pass
        # to the constructor via dict and we have no OrderedDict in 2.6
        for a in ['index', 'x', 'y', 'z']:
            l_.set(a, str(locals()[a]))
        l_.text = l

    # dump XML atlas
    filename = os.path.join(topdir, name + '.xml')
    open(filename, 'w').write(
        etree.tostring(root, pretty_print=True, encoding='ISO-8859-1'))
    return filename


if __name__ == '__main__':

    topdir = 'data/atlases'

    verbose(1, "Loading original atlases")
    # cortical atlas supplies the probability maps to be lateralized;
    # the sub-cortical one supplies the Left-*/Right-* regions used to
    # derive the per-voxel laterality weights
    maps_cor, labels_cor, nifti_hdrs_cor = load_atlas(
                 topdir + '/HarvardOxford-Cortical.xml')
    maps_sub, labels_sub, nifti_hdrs_sub = load_atlas(
                 topdir + '/HarvardOxford-Subcortical.xml')

    verbose(1, "Lateralizing the atlas")
    maps_corl, labels_corl = {}, {}
    # process each available resolution (e.g. 1mm, 2mm) independently
    for r in maps_sub.keys():
        lateralities = get_lateralities(maps_sub[r], labels_sub, r)

        # to carry only laterality probability, we should scale each
        # non-degenerate voxel to add up to 1.0
        normalize_lateralities(lateralities)

        # NOTE: labels_corl is reassigned from the dict above to a list
        # here; labels are identical across resolutions so only the last
        # assignment matters
        maps_corl[r], labels_corl = \
                      lateralize_atlas(maps_cor[r], labels_cor, lateralities)

        # L+R should add up to the not-lateralized one.  Because we
        # are dealing with % in int8, we should allow for few percent
        # difference since some regions might have simply vanished if
        # their original probability was low
        assert(np.max(maps_cor[r]
                      - (maps_corl[r][...,0::2] + maps_corl[r][...,1::2])) <= 4)

    verbose(1, "Storing the atlas")

    cortl_xml = make_atlas(maps_corl,
                         labels_corl,
                         "Harvard-Oxford Cortical Structural Atlas (Lateralized)",
                         name="HarvardOxford-Cortical-Lateralized",
                         shortname="HOCPAL",
                         imagefilename="HarvardOxford-cortl",
                         topdir=topdir,
                         subdir='HarvardOxford',
                         images_thr=[0, 25, 50],
                         nifti_hdrs=nifti_hdrs_cor,
                         )

# if True:
def test_flip_midplane():
    """Sanity helper: write left/right argmax+probability volumes (the
    right one flipped around the first axis) for visual comparison.

    Relies on the module-level maps_sub/labels_sub/nifti_hdrs_sub
    produced by the __main__ section above.
    """
    # Generate two 4d volumes matching left and right labels, with right labels flipped around 41
    # nah -- no time
    # sort (index, name) pairs by region name; `key=` replaces the
    # Python2-only `cmp=`/`cmp()` form with identical (stable) ordering
    labels_sub_sorted = sorted(list(enumerate(labels_sub)),
                               key=lambda x: x[1])

    def get_argmax_map(map_, indexes):
        # 0 for brain-stem so we don't need to add 1 to argmax later on
        # NOTE(review): the hard-coded 7 presumably is the Brain-Stem
        # label index in the sub-cortical atlas -- confirm
        idx = [7] + [x[0] for x in indexes]   # get only indexes without names
        # let's get through temp storage
        map_idx = map_[..., idx]
        map_argmax = np.argmax(map_idx, axis=-1)
        # silly yarik forgotten how to erect needed indices so let's
        # just flatten things out
        map_idx_ = map_idx.reshape((-1, len(idx)))
        n = len(map_idx_)
        # encode label and its probability into a single float volume:
        # integer part = winning label position, fraction = probability
        map_argmax_and_prob = map_argmax.astype(float)
        # reshape of a contiguous array returns a view, so the += below
        # updates map_argmax_and_prob itself
        map_argmax_and_prob_ = map_argmax_and_prob.reshape((-1,))
        map_argmax_and_prob_ += map_idx_[(np.arange(n), map_argmax.reshape((-1,)))]/100.
        return map_argmax_and_prob

    # NOTE(review): slicing assumes after sorting entries [1:11] are the
    # Left-* regions and [11:] the Right-* ones -- verify against labels
    lmap = get_argmax_map(maps_sub[2], labels_sub_sorted[1:11])
    rmap = get_argmax_map(maps_sub[2], labels_sub_sorted[11:])
    # swap rmap around mid-x
    rmap_rev = rmap[::-1]

    hdr = nifti_hdrs_sub[2]
    nib.Nifti1Image(lmap, None, hdr).to_filename('/tmp/test-left.nii.gz')
    nib.Nifti1Image(rmap_rev, None, hdr).to_filename('/tmp/test-right_rev.nii.gz')


# if True:
def test_plot_x(y=64, z=20):
    """Plot raw vs normalized laterality profiles along x at (y, z).

    Fix: the `y`/`z` parameters were previously dead -- immediately
    overwritten by hard-coded 64/20; the overwrites are removed so the
    arguments take effect (defaults preserve the old behavior).
    """
    import pylab as pl
    # NOTE(review): `lateralities_` is not defined anywhere in this
    # file -- presumably a stale name from an interactive session meant
    # to hold the pre-normalization maps; verify before use
    pl.plot(lateralities_['Left'][:, y, z])
    pl.plot(lateralities_['Right'][:, y, z])

    # normalized maps (dashed), from the __main__ loop above
    pl.plot(lateralities['Left'][:, y, z], '--')
    pl.plot(lateralities['Right'][:, y, z], '--')

    # 50% line and the approximate mid-sagittal plane at x=45
    pl.hlines(0.50, 0, 91, linestyle='dashed')
    pl.vlines(45, 0, 1, linestyle='dashed')
    pl.show()

# if True:
def test_store_lateralities():
    """Dump the (normalized) laterality maps from the __main__ loop as a
    single 4D NIfTI volume for visual inspection.

    Relies on the module-level `lateralities` and `nifti_hdrs_cor`.
    """
    # list() around .values() is a no-op on py2 but required on py3,
    # where dict.values() is a view np.array cannot stack directly
    ni_out = nib.Nifti1Image(
        np.rollaxis(np.array(list(lateralities.values())), 0, 4),
        None, nifti_hdrs_cor[2])
    ni_out.to_filename('/tmp/112_.nii.gz')
