#! /usr/bin/env python
from __future__ import print_function

import os
import sys
import glob
from time import time
import itertools
from itertools import islice, chain
from struct import *
import shutil

import argparse
import errno
import math

# import pickle
import dill as pickle 
import gffutils
import pysam

import multiprocessing as mp
from multiprocessing import Pool
from collections import defaultdict
# from collections import OrderedDict

# from modules import create_splice_graph as splice_graph
# from modules import graph_chainer 

from modules import create_augmented_gene as augmented_gene 
from modules import mem_wrapper 
from modules import colinear_solver 
from modules import help_functions
from modules import classify_read_with_mams
from modules import classify_alignment2
from modules import sam_output
from modules import align
from modules import prefilter_genomic_reads


def load_reference(args):
    """Load the reference genome fasta given by ``args.ref``.

    Parameters:
        args: parsed argparse namespace; only ``args.ref`` (path to a fasta
              file) is read.

    Returns:
        (refs, refs_lengths) where
        refs is a dict mapping accession -> sequence string, and
        refs_lengths maps accession -> sequence length.
    """
    # Use a context manager so the fasta handle is closed deterministically;
    # the previous version left the file object open for the process lifetime.
    with open(args.ref, "r") as ref_file:
        refs = {acc: seq for acc, (seq, _) in help_functions.readfq(ref_file)}
    refs_lengths = {acc: len(seq) for acc, seq in refs.items()}
    return refs, refs_lengths

def prep_splicing(args, refs_lengths):
    """Build (or reuse) the gffutils annotation database and derive the
    augmented-gene index structures, pickling each one into the index folder.

    Parameters:
        args: parsed argparse namespace (reads gtf, index, outfolder,
              disable_infer, flank_size, small_exon_threshold, min_segm).
        refs_lengths: dict of reference accession -> sequence length.
    """
    if args.index:
        index_folder = args.index
        help_functions.mkdir_p(index_folder)
    else:
        index_folder = args.outfolder

    database = os.path.join(index_folder, 'database.db')

    if os.path.isfile(database):
        # Reuse an existing database rather than re-parsing the GTF.
        print("Database found in directory using this one.")
        print("If you want to recreate the database, please remove the file: {0}".format(database))
        print()
    else:
        create_kwargs = dict(force=True, keep_order=True, merge_strategy='merge',
                             sort_attribute_values=True)
        if args.disable_infer:
            # Skip gene/transcript inference for GTFs that already contain them.
            create_kwargs['disable_infer_genes'] = True
            create_kwargs['disable_infer_transcripts'] = True
        gffutils.create_db(args.gtf, dbfn=database, **create_kwargs)
    db = gffutils.FeatureDB(database, keep_order=True)

    (segment_to_ref, parts_to_segments, splices_to_transcripts,
     transcripts_to_splices, all_splice_pairs_annotations,
     all_splice_sites_annotations, segment_id_to_choordinates,
     segment_to_gene, gene_to_small_segments, flank_choordinates,
     max_intron_chr, exon_choordinates_to_id, chr_to_id,
     id_to_chr) = augmented_gene.create_graph_from_exon_parts(
        db, args.flank_size, args.small_exon_threshold, args.min_segm, refs_lengths)

    # Persist every index structure; file names are part of the on-disk
    # index format and must match what prep_seqs()/align_reads() load.
    structures = [
        (segment_to_ref, 'segment_to_ref.pickle'),
        (splices_to_transcripts, 'splices_to_transcripts.pickle'),
        (transcripts_to_splices, 'transcripts_to_splices.pickle'),
        (parts_to_segments, 'parts_to_segments.pickle'),
        (all_splice_pairs_annotations, 'all_splice_pairs_annotations.pickle'),
        (all_splice_sites_annotations, 'all_splice_sites_annotations.pickle'),
        (segment_id_to_choordinates, 'segment_id_to_choordinates.pickle'),
        (segment_to_gene, 'segment_to_gene.pickle'),
        (gene_to_small_segments, 'gene_to_small_segments.pickle'),
        (flank_choordinates, 'flank_choordinates.pickle'),
        (max_intron_chr, 'max_intron_chr.pickle'),
        (exon_choordinates_to_id, 'exon_choordinates_to_id.pickle'),
        (chr_to_id, 'chr_to_id.pickle'),
        (id_to_chr, 'id_to_chr.pickle'),
    ]
    for obj, filename in structures:
        help_functions.pickle_dump(index_folder, obj, filename)


def prep_seqs(args, refs, refs_lengths):
    """Extract part/segment/flank/exon sequences from the reference genome
    using the coordinate structures pickled by prep_splicing(), and pickle
    the resulting sequence dictionaries into the index folder.

    Parameters:
        args: parsed argparse namespace (reads index, outfolder,
              use_NAM_seeds, min_mem, mask_threshold).
        refs: dict of reference accession -> sequence.
        refs_lengths: dict of reference accession -> sequence length.
    """
    index_folder = args.index if args.index else args.outfolder

    def _load(name):
        # Small helper: load one pickled index structure from the index folder.
        return help_functions.pickle_load(os.path.join(index_folder, name))

    parts_to_segments = _load('parts_to_segments.pickle')
    segment_id_to_choordinates = _load('segment_id_to_choordinates.pickle')
    segment_to_ref = _load('segment_to_ref.pickle')
    flank_choordinates = _load('flank_choordinates.pickle')
    exon_choordinates_to_id = _load('exon_choordinates_to_id.pickle')
    chr_to_id = _load('chr_to_id.pickle')
    id_to_chr = _load('id_to_chr.pickle')

    print( "Number of ref seqs in gff:", len(parts_to_segments.keys()))

    # Re-key the reference sequences by the integer chromosome IDs used in the
    # annotation index; remember fasta sequences absent from the annotation.
    refs_id = {}
    not_in_annot = set()
    for acc, seq in refs.items():
        acc_id = chr_to_id.get(acc)
        if acc_id is None:
            not_in_annot.add(acc)
        else:
            refs_id[acc_id] = seq

    refs_id_lengths = {acc_id: len(seq) for acc_id, seq in refs_id.items()}
    help_functions.pickle_dump(index_folder, refs_id_lengths, 'refs_id_lengths.pickle')
    help_functions.pickle_dump(index_folder, refs_lengths, 'refs_lengths.pickle')

    print( "Number of ref seqs in fasta:", len(refs.keys()))

    # Warn about mismatches between annotation and fasta in both directions.
    not_in_ref = set(chr_to_id.keys()) - set(refs.keys())
    if not_in_ref:
        print("Warning: Detected {0} sequences that are in annotation but not in reference fasta. Using only sequences present in fasta. The following sequences cannot be detected in reference fasta:\n".format(len(not_in_ref)))
        for s in not_in_ref:
            print(s)

    if not_in_annot:
        print("Warning: Detected {0} sequences in reference fasta that are not in annotation:\n".format(len(not_in_annot)))
        for s in not_in_annot:
            print(s, "with length:{0}".format(len(refs[s])))

    ref_part_sequences = augmented_gene.get_sequences_from_choordinates(parts_to_segments, refs_id)
    ref_flank_sequences = augmented_gene.get_sequences_from_choordinates(flank_choordinates, refs_id)

    if not args.use_NAM_seeds:
        # slaMEM seeding path: mask over-abundant k-mers to avoid spurious MEMs.
        augmented_gene.mask_abundant_kmers(ref_part_sequences, args.min_mem, args.mask_threshold)
        augmented_gene.mask_abundant_kmers(ref_flank_sequences, args.min_mem, args.mask_threshold)

    # Merge flank sequences into the part-sequence structure, then extract the
    # remaining per-segment and per-exon sequences.
    ref_part_sequences = help_functions.update_nested(ref_part_sequences, ref_flank_sequences)
    ref_segment_sequences = augmented_gene.get_sequences_from_choordinates(segment_id_to_choordinates, refs_id)
    ref_exon_sequences = augmented_gene.get_sequences_from_choordinates(exon_choordinates_to_id, refs_id)

    for obj, filename in [
        (segment_id_to_choordinates, 'segment_id_to_choordinates.pickle'),
        (ref_part_sequences, 'ref_part_sequences.pickle'),
        (ref_segment_sequences, 'ref_segment_sequences.pickle'),
        (ref_flank_sequences, 'ref_flank_sequences.pickle'),
        (ref_exon_sequences, 'ref_exon_sequences.pickle'),
    ]:
        help_functions.pickle_dump(index_folder, obj, filename)



def batch(dictionary, size, batch_type):
    """Split a {accession: sequence} dict into a list of sub-dicts, each
    holding roughly `size` nucleotides (a batch is closed as soon as its
    running nt-count reaches `size`).

    Parameters:
        dictionary: dict mapping accession -> sequence string.
        size: target number of nucleotides per batch (batch closes when
              the running total reaches/exceeds this).
        batch_type: unused; kept for backward compatibility with callers
                    that pass 'nt'.

    Returns:
        List of dicts; concatenating them reproduces `dictionary`'s items
        in iteration order. Empty input yields an empty list.

    Fixes over the previous version:
    - leftover items were flushed via `curr_nt_count/size != 0`, which is
      floor division under Python 2 and silently dropped any final partial
      batch smaller than `size`;
    - an empty input dict raised NameError (`acc`/`seq` unbound at flush);
    - the last item was redundantly re-inserted into the final batch.
    """
    batches = []
    sub_dict = {}
    curr_nt_count = 0
    for acc, seq in dictionary.items():
        sub_dict[acc] = seq
        curr_nt_count += len(seq)
        if curr_nt_count >= size:
            # Batch is full: emit it and start a fresh one.
            batches.append(sub_dict)
            sub_dict = {}
            curr_nt_count = 0

    # Flush the final partial batch, if any.
    if sub_dict:
        batches.append(sub_dict)

    return batches


def strobemap_batching(reads, STROBEMAP_BATCH_SIZE, nr_cores):
    """Distribute reads over nr_cores dicts (one per worker thread).

    Reads are processed in windows of STROBEMAP_BATCH_SIZE; within each
    window the reads are split as evenly as possible across the threads:
    the first few threads receive one extra read when the window size is
    not divisible by nr_cores.

    Parameters:
        reads: dict mapping read accession -> sequence.
        STROBEMAP_BATCH_SIZE: number of reads per balancing window.
        nr_cores: number of worker threads / output dicts.

    Returns:
        List of nr_cores dicts; their union is `reads`.
    """
    thread_dicts = [dict() for _ in range(nr_cores)]

    reads_left = len(reads)
    small_chunk = big_chunk = n_big_chunks = 0
    for i, (acc, seq) in enumerate(reads.items()):
        pos = i % STROBEMAP_BATCH_SIZE  # position of this read within its window
        if pos == 0:
            if i > 0:
                # A full window has been consumed; shrink the remaining count.
                reads_left -= min(STROBEMAP_BATCH_SIZE, reads_left)
            window = min(STROBEMAP_BATCH_SIZE, reads_left)
            # window = n_big_chunks * (small_chunk + 1) + (nr_cores - n_big_chunks) * small_chunk
            small_chunk, n_big_chunks = divmod(window, nr_cores)
            big_chunk = small_chunk + 1 if n_big_chunks else small_chunk

        boundary = n_big_chunks * big_chunk
        if pos < boundary:
            # Inside the region served by the threads that get one extra read.
            thread_id = pos // big_chunk
        else:
            thread_id = n_big_chunks + (pos - boundary) // small_chunk

        thread_dicts[thread_id][acc] = seq
    return thread_dicts


def check_alignment_fit(aln_ultra, aln_other):
    """
        returns: 
        1. the differnce in scoring is positive if uLTRA better
        2. the classification obtained by uLTRA
    """
    def _score(aln):
        # Score an alignment as sum(exact matches) - sum(ins/del/subs bases).
        error_ops = {1, 2, 8}  # cigar IDs for INS, DEL, SUBS
        n_match = 0
        n_error = 0
        for op, length in aln.cigartuples:
            if op == 7:  # '=' exact match
                n_match += length
            elif op in error_ops:
                n_error += length
        return n_match - n_error

    return _score(aln_ultra) - _score(aln_other), aln_ultra.get_tag('XC')


def output_final_alignments( ultra_alignments_path, path_indexed_aligned, path_unindexed_aligned):
    """Merge uLTRA's alignments with an alternative aligner's output into the final SAM.

    For each uLTRA primary alignment: if uLTRA left the read unmapped (flag 4)
    but the alternative aligner mapped it, the alternative alignment is used;
    otherwise the two are score-compared (check_alignment_fit) and the better
    one is kept. Reads never attempted by uLTRA (aligned to unindexed regions)
    are appended from path_unindexed_aligned. The merged file atomically
    replaces ultra_alignments_path at the end.

    Parameters:
        ultra_alignments_path: path to uLTRA's SAM output. NOTE(review): it is
            .decode()'d below, so callers appear to pass bytes (pysam's
            AlignmentFile.filename) — confirm against align_reads().
        path_indexed_aligned: SAM from the alternative aligner for reads that
            uLTRA also attempted.
        path_unindexed_aligned: SAM for reads uLTRA never attempted.
    """
    # read in all reads from alternative aligner that was also mapped by uLTRA
    
    alt_alignments_file = pysam.AlignmentFile(path_indexed_aligned, "r", check_sq=False)
    # Keep only primary alignments; keyed by read name for O(1) lookup below.
    alt_alignments = { read.query_name : read for read in alt_alignments_file.fetch(until_eof=True) if not read.is_secondary }


    alignment_infile = pysam.AlignmentFile( ultra_alignments_path, "r" )
    # Write to a temp file first; moved over the original at the end.
    tmp_merged_outfile = pysam.AlignmentFile( ultra_alignments_path.decode()+ 'tmp', "w", template= alignment_infile)
    replaced_unaligned_cntr = 0
    tot_counter = 0
    # Maps (uLTRA score - alternative score) -> number of primary alignments.
    scoring_dict = defaultdict(int)
    for read in alignment_infile.fetch():
        if not read.is_secondary:
            tot_counter += 1

        if read.query_name in alt_alignments:
            if read.flag == 4:
               # uLTRA left the read unmapped: take the alternative alignment.
               read = alt_alignments[ read.query_name ] # replace unmapped uLTRA read with alternative alignment if mapped
               replaced_unaligned_cntr += 1
            elif not read.is_secondary: 
                ultra_scoring_diff, classification = check_alignment_fit(read,  alt_alignments[ read.query_name ])
                scoring_dict[ultra_scoring_diff] += 1
                if ultra_scoring_diff < 0:
                    read = alt_alignments[ read.query_name ] # replace uLTRA read with alternative alignment if better fit

        tmp_merged_outfile.write(read)
    alignment_infile.close()
    # path_genomic_aligned = os.path.join(args.outfolder, "unindexed.sam")

    # add all reads that we did not attempt to align with uLTRA
    # these reads had a primary alignment to unindexed regions by other pre-processing aligner (minimap2 as of now)
    not_attempted_cntr = 0
    unindexed = pysam.AlignmentFile(path_unindexed_aligned, "r")
    for read in unindexed.fetch():
        tmp_merged_outfile.write(read)
        if not read.is_secondary: 
            not_attempted_cntr += 1
    unindexed.close()
    tmp_merged_outfile.close()
    print("{0} reads were not attempted to be aligned with ultra, instead alternative aligner was used.".format(not_attempted_cntr))
    print("{0} reads with primary alignments were replaced with alternative aligner because they were unaligned with uLTRA.".format(replaced_unaligned_cntr))
    print("{0} primary alignments had best fit with uLTRA.".format(sum([v for k,v in scoring_dict.items() if k > 0])))
    print("{0} primary alignments had equal fit.".format(scoring_dict[0]))
    print("{0} primary alignments had best fit with alternative aligner.".format(sum([v for k,v in scoring_dict.items() if k < 0])))

    # Build a cumulative histogram of the score differences:
    # counts[i] ends up as the number of primary alignments with diff < bin_boundaries[i],
    # so counts[i+1] - counts[i] is the count inside [bin_boundaries[i], bin_boundaries[i+1]).
    bin_boundaries = [-2**32, -100,-50,-20,-10,-5,-4,-3,-2,-1, 0, 1, 2, 3, 4, 5, 10, 20, 50, 100, 2**32]
    n = len(bin_boundaries)
    counts = [0]*n #{ (b_l, b_u) : 0 for b_l, b_u in zip(bin_boundaries[:-1], bin_boundaries[1:])}
    start_next = 0
    # Keys are visited in ascending order, so start_next only moves forward;
    # it skips boundaries already known to be <= k (an optimization, not logic).
    for k in sorted(scoring_dict.keys()):
        for i in range(start_next, n):
            b = bin_boundaries[i]
        # for i, b in enumerate(bin_boundaries):
            if k < b:
                counts[i] += scoring_dict[k]
            else:
                start_next = i


    print("Table of score-difference between alignment methods (negative number: alternative aligner better fit, positive number is uLTRA better fit)")
    print("Score is calculated as sum(identities) - sum(ins, del, subs)")
    print("Format: Score difference: Number of primary alignments ")
    for i in range(len(counts)-1):
        print("[{0} - {1}): {2}".format(bin_boundaries[i],bin_boundaries[i+1], counts[i+1] - counts[i]))
    # print("{0} read with primary alignments aligned with uLTRA.".format(tot_counter - replaced_unaligned_cntr - replaced_fit_cntr))

    # Atomically (on same filesystem) replace the original uLTRA SAM with the merged file.
    shutil.move(ultra_alignments_path.decode()+ 'tmp', ultra_alignments_path)


def align_reads(args):
    """Top-level alignment driver: pre-filter genomic reads, find seeds
    (StrobeMap NAMs or slaMEM MEMs), align reads (single- or multi-core),
    merge with the alternative aligner's output, report classification
    counts, and clean up temporary files.

    Parameters:
        args: parsed argparse namespace from the 'align'/'pipeline'
              subcommand. Mutated in place: args.reads is repointed to the
              filtered read file, and args.reads_tmp / args.reads_rc are
              added in the single-core slaMEM path.
    """
    if args.nr_cores > 1:
        # 'spawn' is required for cross-platform safety with the worker pools.
        mp.set_start_method('spawn')
        print(mp.get_context())
        print("Environment set:", mp.get_context())
        print("Using {0} cores.".format(args.nr_cores))

    if args.index:
        if os.path.isdir(args.index):
            index_folder = args.index
        else:
            print("The index folder specified for alignment is not found. You specified: ", args.index )
            print("Build  the index to this folder, or specify another forder where the index has been built." )
            sys.exit()
    else:
        index_folder = args.outfolder

    # topological_sorts = help_functions.pickle_load( os.path.join(args.outfolder, 'top_sorts.pickle') )
    # path_covers = help_functions.pickle_load( os.path.join(args.outfolder, 'paths.pickle') )

    # Load the index structures produced by prep_splicing()/prep_seqs().
    ref_part_sequences = help_functions.pickle_load( os.path.join(index_folder, 'ref_part_sequences.pickle') )
    refs_id_lengths = help_functions.pickle_load( os.path.join(index_folder, 'refs_id_lengths.pickle') )
    refs_lengths = help_functions.pickle_load( os.path.join(index_folder, 'refs_lengths.pickle') )

    if not args.disable_mm2:
        print("Filtering reads aligned to unindexed regions with minimap2 ")
        nr_reads_to_ignore, path_reads_to_align = prefilter_genomic_reads.main(ref_part_sequences, args.ref, args.reads, args.outfolder, index_folder, args.nr_cores, args.genomic_frac, args.mm2_ksize)
        # NOTE: args.reads is repointed to the filtered subset for the rest of the run.
        args.reads = path_reads_to_align
        print("Done filtering. Reads filtered:{0}".format(nr_reads_to_ignore))

    # print(ref_part_sequences)
    # Write all indexed part/flank sequences to one fasta; the header encodes
    # chr_id^start^stop, unpacked later by the seed-processing code.
    ref_path = os.path.join(args.outfolder, "refs_sequences.fa")
    refs_file = open(ref_path, 'w') #open(os.path.join(outfolder, "refs_sequences_tmp.fa"), "w")
    for sequence_id, seq  in ref_part_sequences.items():
        # sequence_id is a struct-packed (chr_id, start, stop) triple.
        chr_id, start, stop = unpack('LLL',sequence_id)
        # for (start,stop), seq  in ref_part_sequences[chr_id].items():
        refs_file.write(">{0}\n{1}\n".format(str(chr_id) + str("^") + str(start) + "^" + str(stop), seq))
    refs_file.close()

    # Free the large sequence dict before forking workers.
    del ref_part_sequences

    ######### FIND MEMS WITH MUMMER #############
    #############################################
    #############################################

    mummer_start = time()
    if args.use_NAM_seeds:
        # StrobeMap path: one pass over reads (handles both strands itself).
        print("Processing reads for MEM finding")
        reads_tmp = open(os.path.join(args.outfolder, 'reads_tmp.fq'), 'w')
        for acc, (seq, qual) in help_functions.readfq(open(args.reads, 'r')):
            # print(seq)
            # print(help_functions.remove_read_polyA_ends(seq, args.reduce_read_ployA))
            # polyA tails are trimmed to 5 bp so they do not seed spurious matches.
            reads_tmp.write('>{0}\n{1}\n'.format(acc, help_functions.remove_read_polyA_ends(seq, args.reduce_read_ployA, 5)))
        reads_tmp.close()
        args.reads_tmp = reads_tmp.name
        mem_wrapper.find_nams_strobemap(args.outfolder, args.reads_tmp, ref_path, args.outfolder, args.nr_cores, args.min_mem)
        print("Time for StrobeMap to find NAMs:{0} seconds.".format(time()-mummer_start))
    else: # Use slaMEM
        if args.nr_cores == 1:
            # Single-core slaMEM: run forward and reverse-complement passes serially.
            print("Processing reads for MEM finding")
            reads_tmp = open(os.path.join(args.outfolder, 'reads_tmp.fq'), 'w')
            for acc, (seq, qual) in help_functions.readfq(open(args.reads, 'r')):
                # print(seq)
                # print(help_functions.remove_read_polyA_ends(seq, args.reduce_read_ployA))
                reads_tmp.write('>{0}\n{1}\n'.format(acc, help_functions.remove_read_polyA_ends(seq, args.reduce_read_ployA, 5)))
            reads_tmp.close()
            args.reads_tmp = reads_tmp.name
            print("Finished processing reads for MEM finding ")

            # Batch index -1 marks the single-core (unbatched) seed files.
            mummer_out_path =  os.path.join( args.outfolder, "seeds_batch_-1.txt" )
            print("Running MEM finding forward") 
            mem_wrapper.find_mems_slamem(args.outfolder, args.reads_tmp, ref_path, mummer_out_path, args.min_mem)
            print("Completed MEM finding forward")

            print("Processing reverse complement reads for MEM finding")
            reads_rc = open(os.path.join(args.outfolder, 'reads_rc.fq'), 'w')
            for acc, (seq, qual) in help_functions.readfq(open(args.reads, 'r')):
                # print(help_functions.reverse_complement(seq))
                # print(help_functions.remove_read_polyA_ends(help_functions.reverse_complement(seq), args.reduce_read_ployA))
                # Trim polyA first, then reverse-complement for the rc pass.
                reads_rc.write('>{0}\n{1}\n'.format(acc, help_functions.reverse_complement(help_functions.remove_read_polyA_ends(seq, args.reduce_read_ployA, 5))))
            reads_rc.close()
            args.reads_rc = reads_rc.name
            print("Finished processing reverse complement reads for MEM finding")

            mummer_out_path =  os.path.join(args.outfolder, "seeds_batch_-1_rc.txt" )
            print("Running MEM finding reverse")
            mem_wrapper.find_mems_slamem(args.outfolder, args.reads_rc, ref_path, mummer_out_path, args.min_mem)
            print("Completed MEM finding reverse")
        
        else: # multiprocess with slaMEM
            # Multi-core slaMEM: split reads into ~equal-nt batches, one
            # forward + one rc slaMEM job per batch, run via a process pool.
            reads = { acc : seq for acc, (seq, qual) in help_functions.readfq(open(args.reads, 'r'))}
            total_nt = sum([len(seq) for seq in reads.values() ])
            batch_size = int(total_nt/int(args.nr_cores) + 1)
            print("batch nt:", batch_size, "total_nt:", total_nt)
            read_batches = batch(reads, batch_size, 'nt')
            
            #### TMP remove not to call mummer repeatedly when bugfixing #### 
            
            batch_args = []
            for i, read_batch_dict in enumerate(read_batches):
                print(len(read_batch_dict))
                read_batch_temp_file = open(os.path.join(args.outfolder, "reads_batch_{0}.fa".format(i)), "w")
                read_batch_temp_file_rc = open(os.path.join(args.outfolder, "reads_batch_{0}_rc.fa".format(i) ), "w")
                for acc, seq in read_batch_dict.items():
                    read_batch_temp_file.write('>{0}\n{1}\n'.format(acc, help_functions.remove_read_polyA_ends(seq, args.reduce_read_ployA, 5)))
                read_batch_temp_file.close()

                for acc, seq in read_batch_dict.items():
                    read_batch_temp_file_rc.write('>{0}\n{1}\n'.format(acc, help_functions.reverse_complement(help_functions.remove_read_polyA_ends(seq, args.reduce_read_ployA, 5))))
                read_batch_temp_file_rc.close()
                
                read_batch = read_batch_temp_file.name
                read_batch_rc = read_batch_temp_file_rc.name
                mummer_batch_out_path =  os.path.join( args.outfolder, "seeds_batch_{0}.txt".format(i) )
                mummer_batch_out_path_rc =  os.path.join(args.outfolder, "seeds_batch_{0}_rc.txt".format(i) )
                batch_args.append( (args.outfolder, read_batch, ref_path, mummer_batch_out_path, args.min_mem ) )
                batch_args.append( (args.outfolder, read_batch_rc, ref_path, mummer_batch_out_path_rc, args.min_mem ) )

            pool = Pool(processes=int(args.nr_cores))
            results = pool.starmap(mem_wrapper.find_mems_slamem, batch_args)
            pool.close()
            pool.join() 
            
            ####################################################################


        print("Time for slaMEM to find mems:{0} seconds.".format(time()-mummer_start))
    #############################################
    #############################################
    #############################################


    print("Starting aligning reads.")
    if args.use_NAM_seeds:
        if args.nr_cores == 1:
            reads = { acc : seq for acc, (seq, qual) in help_functions.readfq(open(args.reads, 'r'))}
            classifications, alignment_outfile_name = align.align_single(reads, refs_id_lengths, args, -1)
        else:
            # OrderedDict # dicts are ordered from python v3.6 and above. 
            # One can use OrderedDict for compatibility with python v 3.4-3.5
            reads = {acc : seq for acc, (seq, qual) in help_functions.readfq(open(args.reads, 'r'))}
            # batch reads and mems up: divide reads by  nr_cores to get batch size
            # then write to separate SAM-files with a batch index, 
            # finally merge sam file by simple cat in terminal 
            aligning_start = time()
            # batch_size = int(len(reads)/int(args.nr_cores) + 1)
            STROBEMAP_BATCH_SIZE=500000
            read_batches = strobemap_batching(reads, STROBEMAP_BATCH_SIZE, int(args.nr_cores))
            print('Nr reads:', len(reads), "nr batches:", len(read_batches), [len(b) for b in read_batches])
            classifications, alignment_outfiles = align.align_parallel(read_batches, refs_id_lengths, args)
        
            print("Time to align reads:{0} seconds.".format(time()-aligning_start))

            # Combine samfiles produced from each batch
            combine_start = time()
            # print(refs_lengths)
            alignment_outfile = pysam.AlignmentFile( os.path.join(args.outfolder, args.prefix+".sam"), "w", reference_names=list(refs_lengths.keys()), reference_lengths=list(refs_lengths.values()) ) #, template=samfile)

            for f in alignment_outfiles:
                samfile = pysam.AlignmentFile(f, "r")
                for read in samfile.fetch():
                    alignment_outfile.write(read)
                samfile.close()

            alignment_outfile.close()
            # NOTE: pysam's .filename is bytes; output_final_alignments() decodes it.
            alignment_outfile_name = alignment_outfile.filename
            print("Time to merge SAM-files:{0} seconds.".format(time() - combine_start))

    else: # Use slaMEM
        if args.nr_cores == 1:
            reads = { acc : seq for acc, (seq, qual) in help_functions.readfq(open(args.reads, 'r'))}
            classifications, alignment_outfile_name = align.align_single(reads, refs_id_lengths, args, -1)
        else:
            # batch reads and mems up: divide reads by  nr_cores to get batch size
            # then write to separate SAM-files with a batch index, 
            # finally merge sam file by simple cat in terminal 
            aligning_start = time()
            batch_size = int(len(reads)/int(args.nr_cores) + 1)
            # read_batches = batch(reads, batch_size)
            # NOTE(review): `reads` and `read_batches` here come from the
            # multi-core slaMEM seeding branch above (same nt-based batching
            # is reused for alignment); `batch_size` computed on the previous
            # line is unused. Verify this coupling is intentional.
            print('Nr reads:', len(reads), "nr batches:", len(read_batches), [len(b) for b in read_batches])
            classifications, alignment_outfiles = align.align_parallel(read_batches, refs_id_lengths, args)
        
            print("Time to align reads:{0} seconds.".format(time()-aligning_start))

            # Combine samfiles produced from each batch
            combine_start = time()
            # print(refs_lengths)
            alignment_outfile = pysam.AlignmentFile( os.path.join(args.outfolder, args.prefix+".sam"), "w", reference_names=list(refs_lengths.keys()), reference_lengths=list(refs_lengths.values()) ) #, template=samfile)

            for f in alignment_outfiles:
                samfile = pysam.AlignmentFile(f, "r")
                for read in samfile.fetch():
                    alignment_outfile.write(read)
                samfile.close()

            alignment_outfile.close()
            alignment_outfile_name = alignment_outfile.filename
            print("Time to merge SAM-files:{0} seconds.".format(time() - combine_start))



    # need to merge genomic/unindexed alignments with the uLTRA-aligned alignments
    if not args.disable_mm2:
        path_indexed_aligned = os.path.join(args.outfolder, "indexed.sam")
        path_unindexed_aligned = os.path.join(args.outfolder, "unindexed.sam")
        output_final_alignments(alignment_outfile_name, path_indexed_aligned, path_unindexed_aligned)

    # Tally per-read classifications (FSM/NIC/... as produced by align.*) and
    # total alignment coverage; unclassified reads count as 'unaligned'.
    counts = defaultdict(int)
    alignment_coverage = 0
    for read_acc in reads:
        if read_acc not in classifications:
            # print(read_acc, "did not meet the threshold")
            pass
        elif classifications[read_acc][0] != 'FSM':
            # print(read_acc, classifications[read_acc]) 
            pass
        if read_acc in classifications:
            # classifications[read_acc] is (category, coverage_fraction).
            alignment_coverage += classifications[read_acc][1]
            if classifications[read_acc][1] < 1.0:
                # print(read_acc, 'alignemnt coverage:', classifications[read_acc][1])
                pass
            counts[classifications[read_acc][0]] += 1
        else:
            counts['unaligned'] += 1


    print(counts)
    print("total alignment coverage:", alignment_coverage)

    if not args.keep_temporary_files:
        print("Deleting temporary files...")
        seeds = glob.glob(os.path.join(args.outfolder, "seeds_*"))
        mum = glob.glob(os.path.join(args.outfolder, "mummer*"))
        sla = glob.glob(os.path.join(args.outfolder, "slamem*"))
        reads_tmp = glob.glob(os.path.join(args.outfolder, "reads_batch*"))
        minimap_tmp = glob.glob(os.path.join(args.outfolder, "minimap2*"))
        ultra_tmp = glob.glob(os.path.join(args.outfolder, "uLTRA_batch*"))
        
        f1 = os.path.join(args.outfolder, "reads_after_genomic_filtering.fasta")
        f2 = os.path.join(args.outfolder, "indexed.sam")
        f3 = os.path.join(args.outfolder, "unindexed.sam")
        f4 = os.path.join(args.outfolder, "refs_sequences.fa")
        # NOTE(review): f5 duplicates f4 ("refs_sequences.fa"); harmless since
        # os.path.isfile guards the second removal, but likely a copy-paste slip.
        f5 = os.path.join(args.outfolder, "refs_sequences.fa")
        f6 = os.path.join(args.outfolder, "reads_rc.fq")
        f7 = os.path.join(args.outfolder, "reads_tmp.fq")
        misc_files = [f1,f2,f3,f4,f5,f6,f7]
        for f in seeds + mum + sla + reads_tmp + minimap_tmp + ultra_tmp+ misc_files:
            if os.path.isfile(f):
                os.remove(f)
                print("removed:", f)
    print("Done.")



if __name__ == '__main__':

    # Command-line interface. Three subcommands:
    #   pipeline -> build index structures AND align reads in one invocation
    #   index    -> only build the splicing database / reference structures
    #   align    -> only classify and align reads against an existing index
    parser = argparse.ArgumentParser(description="uLTRA -- Align and classify long transcriptomic reads based on colinear chaining algorithms to gene regions")
    parser.add_argument('--version', action='version', version='%(prog)s 0.0.4.1')

    subparsers = parser.add_subparsers(help='Subcommands for either constructing a graph, or align reads')

    pipeline_parser = subparsers.add_parser('pipeline', help="Perform all in one: prepare splicing database and reference sequences and align reads.")
    indexing_parser = subparsers.add_parser('index', help="Construct data structures used for alignment.")
    align_reads_parser = subparsers.add_parser('align', help="Classify and align reads with colinear chaining to DAGs")

    # ---------------- pipeline subcommand ----------------
    pipeline_parser.add_argument('ref', type=str, help='Reference genome (fasta)')
    pipeline_parser.add_argument('gtf', type=str, help='Path to gtf file with gene models.')
    pipeline_parser.add_argument('reads', type=str, help='Path to fasta/fastq file with reads.')
    pipeline_parser.add_argument('outfolder', type=str, help='Path to output folder.')
    # Technology presets are mutually exclusive; they override --min_mem/--min_acc below.
    group = pipeline_parser.add_mutually_exclusive_group()
    group.add_argument('--ont', action="store_true", help='Set parameters suitable for ONT (Currently sets: --min_mem 17, --min_acc 0.6 --alignment_threshold 0.5).')
    group.add_argument('--isoseq', action="store_true", help='Set parameters suitable for IsoSeq (Currently sets: --min_mem 20, --min_acc 0.8 --alignment_threshold 0.5).')

    pipeline_parser.add_argument('--min_mem', type=int, default=17, help='Threshold for minimum mem size considered.')
    pipeline_parser.add_argument('--min_segm', type=int, default=25, help='Threshold for minimum segment size considered.')
    pipeline_parser.add_argument('--min_acc', type=float, default=0.5, help='Minimum accuracy of MAM to be considered in mam chaining.')
    pipeline_parser.add_argument('--flank_size', type=int, default=1000, help='Size of genomic regions surrounding genes.')
    pipeline_parser.add_argument('--max_intron', type=int, default=1200000, help='Set global maximum size between mems considered in chaining solution. This is otherwise inferred from GTF file per chromosome.')
    pipeline_parser.add_argument('--small_exon_threshold', type=int, default=200, help='Considered in MAM solution even if they do not contain MEMs.')
    pipeline_parser.add_argument('--reduce_read_ployA', type=int, default=8, help='Reduces polyA tails longer than X bases (default 8) in reads to 5bp before MEM finding. This helps MEM matching to spurious regions but does not affect final read alignments.')
    # NOTE: the threshold is a fraction, so the type must be float (type=int would
    # reject the documented default value 0.5 when passed on the command line).
    pipeline_parser.add_argument('--alignment_threshold', type=float, default=0.5, help='Lower threshold for considering an alignment. \
                                                                                        Counted as the difference between total match score and total mismatch penalty. \
                                                                                        If a read has 25%% errors (under edit distance scoring), the difference between \
                                                                                        matches and mismatches would be (very roughly) 0.75 - 0.25 = 0.5 with default alignment parameters \
                                                                                        match =2, subs=-2, gap open 3, gap ext=1. Default val (0.5) sets that a score\
                                                                                        higher than 2*0.5*read_length would be considered an alignment, otherwise unaligned.')
    pipeline_parser.add_argument('--t', dest='nr_cores', type=int, default=3, help='Number of cores.')
    pipeline_parser.add_argument('--index', type=str, default="", help='Path to where index files will be written to (in indexing step) and read from (in alignment step) [default is the outfolder path].')
    pipeline_parser.add_argument('--prefix', type=str, default="reads", help='Outfile prefix [default=reads]. "--prefix sample_X" will output a file sample_X.sam.')
    pipeline_parser.add_argument('--non_covered_cutoff', type=int, default=15, help='Threshold for what is counted as variation/intron in alignment as opposed to deletion.')
    pipeline_parser.add_argument('--dropoff', type=float, default=0.95, help='Ignore alignment to hits with read coverage of this fraction less than the best hit.')
    # --max_loc is semantically an integer count but kept as type=float for
    # backward compatibility with callers that pass e.g. "5.0".
    pipeline_parser.add_argument('--max_loc', type=float, default=5, help='Limit read to be aligned to at most max_loc places (default 5).\
                                                                            This prevents time blowup for reads from highly repetitive regions (e.g. some genomic intra-priming reads)\
                                                                            but may limit all possible alignments to annotated gene families with many highly similar copies.')
    pipeline_parser.add_argument('--ignore_rc', action='store_true', help='Ignore to map to reverse complement.')
    pipeline_parser.add_argument('--disable_infer', action='store_true', help='Makes splice creation step much faster. This parameter can be set if gene and transcript name fields are provided in gtf file, which is standard for the ones provided by GENCODE and Ensemble.')
    pipeline_parser.add_argument('--mask_threshold', type=int, default=200, help='Abundance occurance threshold. Masks more abundant k-mers than this threshold before MEM finding.')
    pipeline_parser.add_argument('--disable_mm2', action='store_true', help='Disables utilizing minimap2 to detect genomic primary alignments and to quality check uLTRAs primary alignments.\
                                                                                An alignment is classified as genomic if more than --genomic_frac (default 10%%) of its aligned length is outside\
                                                                                regions indexed by uLTRA. Note that uLTRA indexes flank regions such as 3 prime, 5 prime and (parts of) introns.')
    pipeline_parser.add_argument('--genomic_frac', type=float, default=0.1, help='If parameter prefilter_genomic is set, this is the threshold for fraction of aligned portion of read that is outside uLTRA indexed regions to be considered genomic (default 0.1).')
    pipeline_parser.add_argument('--keep_temporary_files', action='store_true', help='Keeps all intermediate files used for the alignment. This parameter is mainly good for bugfixing and development.')
    pipeline_parser.add_argument('--use_NAM_seeds', action='store_true', help='Uses StrobeMap to generate NAM seeds instead of MEMs. Uses StrobeMap parameters ./StrobeMap -n 2 -k 15 -v 16 -w 45 -t [nr_cores] -s.')
    pipeline_parser.set_defaults(which='pipeline')

    # ---------------- index subcommand ----------------
    indexing_parser.add_argument('ref', type=str, help='Reference genome (fasta)')
    indexing_parser.add_argument('gtf', type=str, help='Path to gtf or gtf file with gene models.')
    indexing_parser.add_argument('outfolder', type=str, help='Path to output folder.')
    indexing_parser.add_argument('--min_segm', type=int, default=25, help='Threshold for minimum segment size considered.')
    indexing_parser.add_argument('--flank_size', type=int, default=1000, help='Size of genomic regions surrounding genes.')
    indexing_parser.add_argument('--small_exon_threshold', type=int, default=200, help='Considered in MAM solution even if they do not contain MEMs.')
    indexing_parser.add_argument('--disable_infer', action='store_true', help='Makes splice creation step much faster. This parameter can be set if gene and transcript name fields are provided in gtf file.')
    indexing_parser.add_argument('--min_mem', type=int, default=17, help='Threshold for minimum mem size considered.')
    indexing_parser.add_argument('--mask_threshold', type=int, default=200, help='Abundance occurance threshold. Masks more abundant k-mers than this threshold before MEM finding.')
    indexing_parser.add_argument('--use_NAM_seeds', action='store_true', help='Activate this is you plan to align with parameter --use_NAM_seeds. Will inactivate masking in the indexing step.')

    indexing_parser.set_defaults(which='index')

    # ---------------- align subcommand ----------------
    align_reads_parser.add_argument('ref', type=str, help='Reference genome (fasta).')
    align_reads_parser.add_argument('reads', type=str, help='Path to fasta/fastq file with reads.')
    align_reads_parser.add_argument('outfolder', type=str, help='Path to output folder.')
    align_reads_parser.add_argument('--t', dest='nr_cores', type=int, default=3, help='Number of cores.')
    align_reads_parser.add_argument('--index', type=str, default="", help='Path to where index files will be read from [default is the outfolder path].')
    align_reads_parser.add_argument('--prefix', default="reads", type=str, help='Outfile prefix [default=reads]. "--prefix sample_X" will output a file sample_X.sam.')
    align_reads_parser.add_argument('--max_intron', type=int, default=1200000, help='Set global maximum size between mems considered in chaining solution. This is otherwise inferred from GTF file per chromosome.')
    align_reads_parser.add_argument('--reduce_read_ployA', type=int, default=8, help='Reduces polyA tails longer than X bases (default 8) in reads to 5bp before MEM finding. This helps MEM matching to spurious regions but does not affect final read alignments.')
    # Same float-type fix as in the pipeline subparser above.
    align_reads_parser.add_argument('--alignment_threshold', type=float, default=0.5, help='Lower threshold for considering an alignment. \
                                                                                        Counted as the difference between total match score and total mismatch penalty. \
                                                                                        If a read has 25%% errors (under edit distance scoring), the difference between \
                                                                                        matches and mismatches would be (very roughly) 0.75 - 0.25 = 0.5 with default alignment parameters \
                                                                                        match =2, subs=-2, gap open 3, gap ext=1. Default val (0.5) sets that a score\
                                                                                        higher than 2*0.5*read_length would be considered an alignment, otherwise unaligned.')
    align_reads_parser.add_argument('--non_covered_cutoff', type=int, default=15, help='Threshold for what is counted as variation/intron in alignment as opposed to deletion.')
    align_reads_parser.add_argument('--dropoff', type=float, default=0.95, help='Ignore alignment to hits with read coverage of this fraction less than the best hit.')
    align_reads_parser.add_argument('--max_loc', type=float, default=5, help='Limit read to be aligned to at most max_loc places (default 5).\
                                                                            This prevents time blowup for reads from highly repetitive regions (e.g. some genomic intra-priming reads)\
                                                                            but may limit all possible alignments to annotated gene families with many highly similar copies.')

    align_reads_parser.add_argument('--ignore_rc', action='store_true', help='Ignore to map to reverse complement.')
    align_reads_parser.add_argument('--min_mem', type=int, default=17, help='Threshold for minimum mem size considered.')
    align_reads_parser.add_argument('--min_acc', type=float, default=0.5, help='Minimum accuracy of MAM to be considered in mam chaining.')

    align_reads_parser.add_argument('--disable_mm2', action='store_true', help='Disables utilizing minimap2 to detect genomic primary alignments and to quality check uLTRAs primary alignments.\
                                                                                An alignment is classified as genomic if more than --genomic_frac (default 10%%) of its aligned length is outside\
                                                                                regions indexed by uLTRA. Note that uLTRA indexes flank regions such as 3 prime, 5 prime and (parts of) introns.')

    align_reads_parser.add_argument('--genomic_frac', type=float, default=0.1, help='If parameter prefilter_genomic is set, this is the threshold for fraction of aligned portion of read that is outside uLTRA indexed regions to be considered genomic (default 0.1).')
    align_reads_parser.add_argument('--keep_temporary_files', action='store_true', help='Keeps all intermediate files used for the alignment. This parameter is mainly good for bugfixing and development.')
    align_reads_parser.add_argument('--use_NAM_seeds', action='store_true', help='Uses StrobeMap to generate NAM seeds instead of MEMs. Uses StrobeMap parameters ./StrobeMap -n 2 -k 15 -v 16 -w 45 -t [nr_cores] -s.')

    group2 = align_reads_parser.add_mutually_exclusive_group()
    group2.add_argument('--ont', action="store_true", help='Set parameters suitable for ONT (Currently sets: --min_mem 17, --min_acc 0.6 --alignment_threshold 0.5).')
    group2.add_argument('--isoseq', action="store_true", help='Set parameters suitable for IsoSeq (Currently sets: --min_mem 20, --min_acc 0.8 --alignment_threshold 0.5).')

    align_reads_parser.set_defaults(which='align_reads')

    args = parser.parse_args()

    # No subcommand given (e.g. bare invocation or only global flags): the
    # Namespace lacks 'which' and 'outfolder', so bail out with the help text
    # instead of raising AttributeError below.
    if not hasattr(args, 'which'):
        parser.print_help()
        sys.exit()

    help_functions.mkdir_p(args.outfolder)

    # Technology presets override the seed/accuracy defaults for alignment modes.
    if args.which == 'align_reads' or args.which == 'pipeline':
        args.mm2_ksize = 15
        if args.ont:
            args.min_mem = 17
            args.min_acc = 0.6
            args.mm2_ksize = 14
        if args.isoseq:
            args.min_mem = 20
            args.min_acc = 0.8

    # Dispatch to the requested stage(s).
    if args.which == 'index':
        args.index = args.outfolder
        refs, refs_lengths = load_reference(args)
        prep_splicing(args, refs_lengths)
        prep_seqs(args, refs, refs_lengths)
    elif args.which == 'align_reads':
        align_reads(args)
    elif args.which == 'pipeline':
        refs, refs_lengths = load_reference(args)
        prep_splicing(args, refs_lengths)
        prep_seqs(args, refs, refs_lengths)
        align_reads(args)
    else:
        print('invalid call')