#! /usr/bin/env python
# Author: Martin C. Frith 2019
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import print_function

import gzip
import logging
import optparse
import signal
import subprocess
import sys

def openFile(fileName):
    """Return a readable text stream for fileName.

    The name "-" selects standard input.  A name ending in ".gz" is opened
    through gzip in text mode.  Any other name is opened as an ordinary
    file.
    """
    if fileName == "-":
        return sys.stdin
    isGzipped = fileName.endswith(".gz")
    if not isGzipped:
        return open(fileName)
    # Mode "rt" requests decoded text; this is dubious on Python 2.
    return gzip.open(fileName, "rt")

def fastxInput(lines):
    """Yield each sequence record in lines as a list of its lines.

    A new record begins at a line starting with ">" when the record being
    collected also started with ">", or at a line starting with "@" when
    the record being collected already holds exactly four lines.  Any
    lines left over at the end are yielded as a final record.
    """
    record = []
    for line in lines:
        marker = line[0]
        if marker == ">":
            startsNewRecord = bool(record) and record[0][0] == ">"
        elif marker == "@":
            # The four-line check stops a quality line that happens to
            # begin with "@" from being taken as a record header.
            startsNewRecord = len(record) == 4
        else:
            startsNewRecord = False
        if startsNewRecord:
            yield record
            record = []
        record.append(line)
    if record:
        yield record

def seqGroupsFromFile(lines):
    """Read group names and their member sequence names from lines.

    Each line starting with "# PART " opens a new group, named by the third
    whitespace-separated field of that line.  Within a group, every second
    line starting with "s" names a member sequence: the name is that line's
    second field with its final character removed (presumably a strand
    suffix appended by the program that wrote the file -- TODO confirm).

    "s" lines that come before the first "# PART " line belong to no group
    and are ignored.

    Returns (seqGroups, seqNamesToGroupNums).  seqGroups is a list of
    (groupName, []) tuples in file order, each holding its own empty list
    for the caller to fill.  seqNamesToGroupNums maps each member sequence
    name to the index of its group in seqGroups; if a name occurs in more
    than one group, the last occurrence wins.
    """
    seqGroups = []
    seqNamesToGroupNums = {}
    sLineCount = 0
    for line in lines:
        if line.startswith("# PART "):
            groupName = line.split()[2]
            seqGroups.append((groupName, []))
            sLineCount = 0  # "s" lines are counted afresh in each group
        elif seqGroups and line.startswith("s"):
            sLineCount += 1
            # Only even-numbered "s" lines are used (assumes "s" lines come
            # in pairs whose second line is the grouped sequence).
            if sLineCount % 2 == 0:
                fields = line.split()
                seqName = fields[1][:-1]
                seqNamesToGroupNums[seqName] = len(seqGroups) - 1
    return seqGroups, seqNamesToGroupNums

def main(opts, args):
    """Print the grouped sequences, or merge each group with lamassemble.

    opts supplies the attributes verbose, gap_max, seq_min, prob, P, W, m
    and z.  args holds two or three file names: the first is the sequence
    file and the last is the groups file.

    With two arguments, every input record whose name (the first word after
    the leading marker character of its first line) belongs to a group is
    printed unchanged, and nothing else is done.

    With three arguments, the records are collected per group, and for each
    group one "lamassemble" process is run with the group's records piped
    to its standard input.  The middle argument is passed to lamassemble as
    a positional argument, followed by "-".  The merged sequence name is the
    group name with "group" replaced by "merge".  A warning is logged when
    lamassemble exits with a non-zero status, so that a missing consensus
    does not go unnoticed; the remaining groups are still processed.
    """
    logLevel = logging.INFO if opts.verbose else logging.WARNING
    logging.basicConfig(format="%(filename)s: %(message)s", level=logLevel)

    seqGroups, seqNamesToGroupNums = seqGroupsFromFile(openFile(args[-1]))
    isPrintOnly = (len(args) == 2)

    for seqLines in fastxInput(openFile(args[0])):
        seqName = seqLines[0][1:].split()[0]
        if seqName not in seqNamesToGroupNums:
            continue
        if isPrintOnly:
            print(*seqLines, sep="", end="")
        else:
            seqs = seqGroups[seqNamesToGroupNums[seqName]][1]
            seqs.append(seqLines)

    if isPrintOnly:
        return

    # These options are the same for every group, so build them once.
    # They keep their original order: the per-group "-n" option goes first.
    commonArgs = ["-g" + str(opts.gap_max),
                  "-s" + str(opts.seq_min),
                  "-p" + str(opts.prob)]
    if opts.verbose:
        commonArgs.append("-" + "v" * opts.verbose)
    commonArgs += ["-P" + str(opts.P),
                   "-W" + str(opts.W),
                   "-m" + str(opts.m),
                   "-z" + str(opts.z),
                   args[1], "-"]

    for groupName, seqs in seqGroups:
        mergedSequenceName = groupName.replace("group", "merge")
        cmd = ["lamassemble", "-n" + mergedSequenceName] + commonArgs
        logging.info(" ".join(cmd))
        proc = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                universal_newlines=True)
        for seqLines in seqs:
            proc.stdin.writelines(seqLines)
        proc.stdin.close()
        status = proc.wait()
        if status != 0:
            logging.warning("lamassemble exited with status %d for %s",
                            status, mergedSequenceName)

if __name__ == "__main__":
    # Use the default SIGPIPE action, so that a closed output pipe does not
    # produce a needless error message.
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)

    usageLines = [
        "%prog [options] sequences.fx last-train.out groups.maf > out.fa",
        "   or: %prog sequences.fx groups.maf > unmerged-sequences.fx",
    ]
    summary = "Merge DNA sequences of each group into a consensus sequence."
    parser = optparse.OptionParser(usage="\n".join(usageLines),
                                   description=summary)

    # Options of this program itself.
    parser.add_option(
        "-g", "--gap-max", metavar="G", type="float", default=50,
        help="use alignment columns with <= G% gaps (default=%default)")
    parser.add_option(
        "-s", "--seq-min", metavar="S", default="1",
        help="omit consensus flanks with < S sequences (default=%default)")
    parser.add_option(
        "-p", "--prob", metavar="P", type="float", default=0.002,
        help="use pairwise restrictions with error probability "
        "<= P (default=%default)")
    parser.add_option(
        "-v", "--verbose", action="count", default=0,
        help="show progress messages")

    # Options listed under their own "LAST options" heading in the help.
    lastOptions = optparse.OptionGroup(parser, "LAST options")
    lastOptions.add_option(
        "-P", type="int", default=1,
        help="number of parallel threads (default=%default)")
    lastOptions.add_option(
        "-W", type="int", default=19,
        help="use minimum positions in length-W windows (default=%default)")
    lastOptions.add_option(
        "-m", type="int", default=5,
        help="max initial matches per query position (default=%default)")
    lastOptions.add_option(
        "-z", type="int", default=30,
        help="max gap length (default=%default)")
    parser.add_option_group(lastOptions)

    opts, args = parser.parse_args()
    if len(args) not in (2, 3):
        parser.print_help()
    else:
        main(opts, args)
