#!/usr/bin/env python3
#
# Sniffles2
# A fast structural variant caller for long-read sequencing data
#
# Created: 28.05.2021
# Author:  Moritz Smolka
# Contact: moritz.g.smolka@gmail.com
#

# Developer switch: when True, Sniffles2_Main adds the total memory usage to its progress output
# by calling dbg_get_total_memory_usage_MB(). That helper currently only exists inside a disabled
# (string-quoted) block further down in this file, so it has to be re-enabled together with this flag.
DEV_MONITOR_MEM=False

import sys
# Refuse to run on unsupported interpreters; this check sits before the remaining imports of the file.
if sys.version_info<(3,7):
    print(f"Error: Sniffles2 must be run with Python version 3.7 or above (detected Python version: {sys.version_info.major}.{sys.version_info.minor}). Exiting")
    # sys.exit() instead of the bare exit() helper: exit() is injected by the site module
    # and is not available in every environment (e.g. when started with python -S).
    sys.exit(1)

import multiprocessing
import collections
import math
import time
import os
import json

import pysam

from dataclasses import dataclass

from sniffles import config
from sniffles import vcf
from sniffles import snf
from sniffles import parallel
from sniffles import util

#TODO: Dev/Debugging only - Remove for prod
"""
if DEV_MONITOR_MEM:
    import psutil
    def dbg_get_total_memory_usage_MB():
        total=0
        n=0
        proc=psutil.Process(os.getpid())
        for child in proc.children(recursive=True):
            total+=child.memory_info().rss
            n+=1
        total+=proc.memory_info().rss
        return total/(1000.0*1000.0)
"""
#END:TODO

def Sniffles2_Main(config,processes):
    """Run a complete Sniffles2 job: choose the run mode, plan tasks, drive the workers and write the output.

    The run mode is derived from the extension(s) of config.input:
      - a single .bam/.cram file: "call_sample", or "genotype_vcf" if config.genotype_vcf is set
      - one or more .snf files, or a single .tsv list of .snf files: "combine"

    Args:
        config: Run configuration object (the entry point of this file passes config.from_cmdline()).
            It is both read and modified here: derived settings such as mode, sample_ids_vcf and
            contig_lengths are stored on it, and it is handed to every worker process.
        processes: List that is extended in place with the parallel.Process wrappers of all workers
            created here, so the caller can still shut them down if this function is left early.

    Fatal problems are reported through util.fatal_error_main(); the code below relies on that
    call not returning (presumably it raises util.Sniffles2Exit, which the entry point of this
    file catches - confirm in sniffles.util).
    """
    #
    # Determine Sniffles2 run mode
    #

    input_ext=[f.split(".")[-1].lower() for f in config.input]

    if len(set(input_ext)) > 1:
        util.fatal_error_main(f"Please specify either: A single .bam/.cram file - OR - one or more .snf files - OR - a single .tsv file containing a list of .snf files and optional sample ids as input. (supplied were: {list(set(input_ext))})")

    if "bam" in input_ext or "cram" in input_ext:
        #Count .bam and .cram together so that the reported number is also correct for .cram input
        alignment_file_count=input_ext.count("bam")+input_ext.count("cram")
        if alignment_file_count > 1:
            util.fatal_error_main(f"Please specify max 1 .bam/.cram file as input (got {alignment_file_count})")
        config.input=config.input[0]

        if config.genotype_vcf is not None:
            config.mode="genotype_vcf"
        else:
            config.mode="call_sample"

        config.input_is_cram=False
        if "bam" in input_ext:
            config.input_mode="rb"
        elif "cram" in input_ext:
            config.input_mode="rc"
            config.input_is_cram=True

    elif "snf" in input_ext or "tsv" in input_ext:
        config.mode="combine"
    else:
        util.fatal_error_main(f"Failed to determine run mode from input. Please specify either: A single .bam file - OR - one or more .snf files - OR - a single .tsv file containing a list of .snf files and optional sample ids as input. (supplied were: {list(set(input_ext))})")

    if config.mode != "call_sample" and config.snf is not None:
        util.fatal_error_main(f"--snf cannot be used with run mode {config.mode}")

    if config.vcf is None and config.snf is None:
        util.fatal_error_main("Please specify at least one of: --vcf or --snf for output (both may be used at the same time)")

    if config.mode=="call_sample":
        if config.sample_id is None:
            config.sample_ids_vcf=[(0,"SAMPLE")]
        else:
            config.sample_ids_vcf=[(0,config.sample_id)]

    elif config.mode=="combine":
        config.sample_id=None

        if config.combine_consensus:
            config.sample_ids_vcf=[(0,"CONSENSUS")]
        else:
            config.sample_ids_vcf=[] #Determined from .snf headers

    print(f"Running {config.version}, build {config.build}")
    print(f"  Run Mode: {config.mode}")
    print(f"  Start on: {config.start_date}")
    print(f"  Working dir: {config.workdir}")
    print(f"  Used command: {config.command}")
    print("==============================")

    #
    # call_sample/genotype_vcf: Open .bam file for single calling / .snf creation
    #
    contig_tandem_repeats={}
    if config.mode=="call_sample" or config.mode=="genotype_vcf":
        print(f"Opening for reading: {config.input}")
        bam_in=pysam.AlignmentFile(config.input, config.input_mode)
        try:
            has_index=bam_in.check_index()
            if not has_index:
                raise ValueError
        except ValueError:
            util.fatal_error_main(f"Unable to load index for input file '{config.input}'. Please verify that your input file is sorted + indexed and that the index .bai file is valid and in the right location.")

        #
        #Load tandem repeat annotations
        #
        if config.tandem_repeats is not None:
            contig_tandem_repeats=util.load_tandem_repeats(config.tandem_repeats,config.tandem_repeat_region_pad)
            print(f"Opening for reading: {config.tandem_repeats} (tandem repeat annotations for {len(contig_tandem_repeats)} contigs)")

    #
    #genotype_vcf: Read SVs from VCF to be genotyped
    #
    if config.mode=="genotype_vcf":
        _,ext=os.path.splitext(config.genotype_vcf)
        ext=ext.lower()
        if ext==".gz":
            vcf_in_handle=pysam.BGZFile(config.genotype_vcf,"rb")
        elif ext==".vcf":
            vcf_in_handle=open(config.genotype_vcf,"r")
        else:
            util.fatal_error_main("Expected a .vcf or .vcf.gz file for genotyping using --genotype-vcf")
        vcf_in=vcf.VCF(config,vcf_in_handle)

        genotype_lineindex_order=[]            #Input line indices in file order, used to keep the output order
        genotype_lineindex_svs={}              #Input line index -> SV read from the input VCF
        genotype_contig_svs={}                 #Contig -> list of SVs on that contig
        genotype_lineindex_svcalls_returned={} #Input line index -> genotyped SV returned by a worker
        for svcall in vcf_in.read_svs_iter():
            if not svcall.contig in genotype_contig_svs:
                genotype_contig_svs[svcall.contig]=[]
            assert(not svcall.raw_vcf_line_index in genotype_lineindex_svs)
            genotype_lineindex_order.append(svcall.raw_vcf_line_index)
            genotype_lineindex_svs[svcall.raw_vcf_line_index]=svcall
            genotype_contig_svs[svcall.contig].append(svcall)
        print(f"Opening for reading: {config.genotype_vcf} (read {len(genotype_lineindex_svs)} SVs to be genotyped)")

    #
    # Open output files
    #
    config.vcf_output_bgz=False
    if config.vcf is not None:
        _,ext=os.path.splitext(config.vcf)
        if ext==".gz" or ext==".bgz":
            config.vcf_output_bgz=True

        #Human readable summary of the output properties (always holds at least one entry)
        vcf_output_info=[]
        if config.mode=="combine":
            vcf_output_info.append("multi-sample")
        else:
            vcf_output_info.append("single-sample")
        if config.sort:
            vcf_output_info.append("sorted")
        if config.vcf_output_bgz:
            vcf_output_info.append("bgzipped")
            vcf_output_info.append("tabix-indexed")
        vcf_output_info_str=f"({', '.join(vcf_output_info)})"

        if config.vcf_output_bgz:
            if not config.sort:
                util.fatal_error_main(".gz (bgzip) output is only supported with sorting enabled")
            vcf_handle=pysam.BGZFile(config.vcf,"w")
        else:
            vcf_handle=open(config.vcf,"w")

        vcf_out=vcf.VCF(config,vcf_handle)

        if config.mode=="call_sample" or config.mode=="combine":
            if config.reference is not None:
                print(f"Opening for reading: {config.reference}")
            vcf_out.open_reference()

        print(f"Opening for writing: {config.vcf} {vcf_output_info_str}")

    if config.snf is not None:
        print(f"Opening for writing: {config.snf}")
        snf_out=snf.SNFile(config,open(config.snf,"wb"))
        snf_parts=[]

    #
    # Plan multiprocessing tasks
    #
    task_id=0
    tasks_list=[]
    contig_tasks_intervals={} #Contig -> list of (start,end,task), used to re-assign calls to tasks when sorting

    if config.mode=="call_sample" or config.mode=="genotype_vcf":
        #
        # Process .bam header
        #
        total_mapped=bam_in.mapped
        if (config.threads==1 and not config.low_memory) or config.task_count_multiplier==0:
            task_max_reads=total_mapped
        else:
            task_max_reads=max(1,math.floor(total_mapped/(config.threads*config.task_count_multiplier)))

        if total_mapped==0:
            #Total mapped returns 0 for CRAM files
            config.task_read_id_offset_mult=10**9
        else:
            #BAM file
            #NOTE(review): math.log is the natural logarithm, so this exponent is at least as large
            #as a log10 based one would be. Left untouched, as changing it alters the multiplier.
            config.task_read_id_offset_mult=10**math.ceil(math.log(total_mapped)+1)

        contig_lengths=[]
        contig_coverages={}
        contigs_with_tr_annotations=0
        for contig in bam_in.get_index_statistics():
            if task_max_reads==0:
                task_count=1
            else:
                task_count=max(1,math.ceil(contig.mapped / float(task_max_reads)))
            contig_str=str(contig.contig)
            contig_length=bam_in.get_reference_length(contig_str)
            contig_lengths.append((contig_str, contig_length))
            contig_coverages[contig_str]=[]
            #At least 1, otherwise the loop below would never advance for contigs shorter than task_count
            task_length=max(1,math.floor(contig_length/float(task_count)))
            contigs_with_tr_annotations+=int(contig_str in contig_tandem_repeats)
            startpos=0

            #Cut the contig into consecutive half-open intervals [startpos,endpos), one task each
            while startpos < contig_length-1:
                endpos=min(contig_length-1,startpos+task_length)
                if config.genotype_vcf is not None:
                    if contig_str in genotype_contig_svs:
                        genotype_svs=[target_sv for target_sv in genotype_contig_svs[contig_str] if target_sv.pos >= startpos and target_sv.pos < endpos]
                    else:
                        genotype_svs=[]
                else:
                    genotype_svs=None

                task=parallel.Task(id=task_id,
                                   contig=contig_str,
                                   start=startpos,
                                   end=endpos,
                                   assigned_process_id=None,
                                   tandem_repeats=contig_tandem_repeats[contig_str] if contig_str in contig_tandem_repeats else None,
                                   genotype_svs=genotype_svs,
                                   sv_id=0)
                tasks_list.append(task)
                if not contig_str in contig_tasks_intervals:
                    contig_tasks_intervals[contig_str]=[]
                contig_tasks_intervals[contig_str].append((task.start,task.end,task))
                startpos+=task_length
                task_id+=1
        config.contig_lengths=contig_lengths

        #The main process only needs the header/index statistics read above
        bam_in.close()

        if contigs_with_tr_annotations < len(contig_lengths) and config.tandem_repeats is not None:
            print(f"Info: {contigs_with_tr_annotations} of {len(contig_lengths)} contigs in the input sample have associated tandem repeat annotations.")

            if contigs_with_tr_annotations==0:
                util.fatal_error_main("A tandem repeat annotations file was provided, but no matching annotations were found for any contig in the sample input file. Please check if the contig naming scheme in the tandem repeat annotations matches with the one in the input sample file.")

    elif config.mode=="combine":
        #
        # Process .snf headers
        #
        config.snf_input_info=[]
        total_mapped=0
        snf_internal_id=0

        input_snfs_sample_ids=[] #List of (.snf filename, sample id override or None)

        if len(config.input)==1 and input_ext[0]=="tsv":
            print(f"Opening for reading: {config.input[0]} (loading list of .snf files and associated sample ids)")
            with open(config.input[0],"r") as tsv_handle:
                for line_index, line in enumerate(tsv_handle):
                    line_strip=line.strip()
                    if len(line_strip)==0 or line_strip[0]=="#":
                        continue
                    parts=line_strip.split("\t")
                    if len(parts)==1:
                        snf_filename=parts[0]
                        sample_id=None
                    elif len(parts)==2:
                        snf_filename=parts[0]
                        sample_id=parts[1]
                    else:
                        util.fatal_error_main(f"Invalid sample list .tsv : {config.input[0]} : Line {line_index+1} - expected either one or two columns (first column: .snf filename, second column: optional sample id to overrule the one specified in the .snf file)")
                    input_snfs_sample_ids.append((snf_filename,sample_id))
            if len(input_snfs_sample_ids)==0:
                #Without this check, the task planning below would fail on the undefined contig_lengths
                util.fatal_error_main(f"Invalid sample list .tsv : {config.input[0]} : No .snf files are listed")
        elif input_ext[0]=="snf":
            input_snfs_sample_ids=[(item,None) for item in config.input]
        else:
            util.fatal_error_main("Failed to determine .snf files to be combined. Please specify either one or more .snf files OR a single .tsv file as input for multi-calling.")

        for input_filename, sample_id in input_snfs_sample_ids:
            snf_in=snf.SNFile(config,open(input_filename,"rb"),filename=input_filename)
            snf_in.read_header()
            total_mapped+=snf_in.header["snf_candidate_count"]
            #The contig lengths of the last .snf file read are the ones used for task planning below
            contig_lengths=snf_in.header["config"]["contig_lengths"]
            if not config.dev_skip_snf_validation:
                if config.snf_block_size != snf_in.header["config"]["snf_block_size"]:
                    util.fatal_error_main(f"SNF block size differs for {input_filename}")
                if config.snf_format_version != snf_in.header["config"]["snf_format_version"]:
                    util.fatal_error_main(f"SNF format version for {input_filename} is not supported")
            #Sample id priority: .tsv column, then id stored in the .snf header, then the file name
            if sample_id is None:
                if snf_in.header["config"]["sample_id"] is not None:
                    sample_id=snf_in.header["config"]["sample_id"]
                else:
                    sample_id,_=os.path.splitext(os.path.basename(input_filename))
            config.snf_input_info.append({"internal_id":snf_internal_id, "sample_id": sample_id, "filename": input_filename})
            snf_internal_id+=1
            snf_in.handle.close()

        if not config.combine_consensus:
            for info in config.snf_input_info:
                config.sample_ids_vcf.append((info["internal_id"],info["sample_id"]))

        #TODO: Assure header consistency across multiple .snfs

        #One task per contig in combine mode
        for contig_str,contig_length in contig_lengths:
            task=parallel.Task(id=task_id,contig=contig_str,start=0,end=contig_length-1,assigned_process_id=None,sv_id=0)
            tasks_list.append(task)
            if not contig_str in contig_tasks_intervals:
                contig_tasks_intervals[contig_str]=[]
            contig_tasks_intervals[contig_str].append((task.start,task.end,task))
            task_id+=1

        #len(config.input) would be 1 for a .tsv list, so count the .snf files actually loaded
        print(f"Verified headers for {len(config.snf_input_info)} .snf files.")
        print("The following samples will be processed in multi-calling:")
        for info in config.snf_input_info:
            print(f"    {info['filename']} (sample ID in output VCF='{info['sample_id']}')")

    if config.mode!="genotype_vcf" and config.vcf is not None:
        vcf_out.write_header(contig_lengths)

    #
    # Start workers
    #
    for proc_id in range(config.threads):
        pipe_main,pipe_proc=multiprocessing.Pipe()
        worker=multiprocessing.Process(target=parallel.Main,args=(proc_id,config,pipe_proc))
        processes.append(parallel.Process(id=proc_id,process=worker,pipe_main=pipe_main,externals=[]))

    processes_free=list(processes)
    processes_busy=[]
    tasks_todo=list(tasks_list)

    if config.vcf is not None and config.sort:
        task_id_calls={} #Task id -> SV calls returned for that task, buffered until everything is done

    for proc in processes:
        proc.process.start()

    print("")
    if config.mode=="call_sample" or config.mode=="genotype_vcf":
        if config.input_is_cram:
            #CRAM file
            print(f"Analyzing alignments... (progress display disabled for CRAM input)")
        else:
            print(f"Analyzing {total_mapped} alignments total...")
    elif config.mode=="combine":
        print(f"Calling SVs across {len(config.snf_input_info)} samples ({total_mapped} candidates total)...")
    print("")

    #
    # Distribute analysis tasks to workers and collect results
    #
    analysis_start_time=time.time()
    def show_progress():
        """Print one progress line; reads the counters of the enclosing function at call time."""
        if config.no_progress:
            return

        if processed_read_count > 1000:
            obj_title="?"
            if config.mode=="call_sample" or config.mode=="genotype_vcf":
                obj_title="alignments"
            elif config.mode=="combine":
                obj_title="candidates"
            read_progress=f"{str(processed_read_count).rjust(len(str(total_mapped)))}/{total_mapped} {obj_title} processed ({int(100*float(processed_read_count/total_mapped))}%, {int(processed_read_count/(time.time()-time_build_start))}/s); "
        else:
            read_progress="(Estimating progress); "
        if DEV_MONITOR_MEM:
            mem_str=f" [DEV: Total memory usage={dbg_get_total_memory_usage_MB():.2f}MB]"
        else:
            mem_str=""

        count_str=""
        if config.mode=="call_sample" and config.snf is not None:
            count_str+=f" {snf_candidate_count} candidates."
        if (config.mode=="call_sample" and config.vcf is not None) or config.mode=="combine":
            count_str+=f" {sv_count} SVs."
        print(f"{read_progress}{tasks_done_count}/{len(tasks_list)} tasks done; parallel {len(processes_busy)}/{len(processes)};{count_str} {mem_str}")

    last_tasks_todo=len(tasks_todo)
    tasks_done_count=0
    processed_read_count=0
    sv_count=0
    snf_candidate_count=0
    last_processed_read_pct=0
    time_build_start=time.time()
    seed=100
    while len(tasks_todo) > 0 or len(processes_busy) > 0:
        #Hand out tasks for as long as there are both idle workers and open tasks
        while len(processes_free) > 0 and len(tasks_todo) > 0:
            #Deterministic pseudo-random sequence, used to pick a task in low memory mode
            seed=(seed*1103515245+12345)%(2**31)
            if config.low_memory:
                task=tasks_todo.pop(seed%len(tasks_todo))
            else:
                task=tasks_todo.pop()
            proc=processes_free.pop()
            processes_busy.append(proc)
            assert(task.assigned_process_id is None)
            task.assigned_process_id=proc.id
            proc.pipe_main.send([config.mode,task])

        if total_mapped > 0:
            #Progress disabled for CRAM files due to total read count not being stored in index
            processed_read_pct=int(100*float(processed_read_count/total_mapped))
            if len(tasks_todo)!=last_tasks_todo and processed_read_pct - last_processed_read_pct >= 5:
                last_processed_read_pct=processed_read_pct
                show_progress()
                last_tasks_todo=len(tasks_todo)

        #Iterate over a copy: finished workers are removed from processes_busy inside the loop
        for proc in list(processes_busy):
            if not proc.pipe_main.poll(0.01):
                continue

            #Worker has sent a message: receive and process the result of its task
            command,result=proc.pipe_main.recv()
            if command=="worker_exception":
                util.fatal_error_main("An exception occured in a worker process (see surrounding error messages)")
            assert(command=="return_"+config.mode)
            processed_read_count+=result["processed_read_count"]
            processes_busy.remove(proc)
            processes_free.append(proc)
            tasks_done_count+=1

            task_id=result["task_id"]
            task=tasks_list[task_id]
            assert(task.id==task_id)

            if config.mode=="call_sample" or config.mode=="combine":
                if config.vcf is not None:
                    svcalls=result["svcalls"]
                    sv_count+=len(svcalls)

                    if config.sort:
                        #Buffer the calls, they are written in sorted order after all tasks are done
                        if not task_id in task_id_calls:
                            task_id_calls[task_id]=[]
                        task_id_calls[task_id].extend(svcalls)
                    else:
                        for svcall in svcalls:
                            vcf_out.write_call(svcall)

            elif config.mode=="genotype_vcf":
                for svcall in result["svcalls"]:
                    assert(not (svcall.raw_vcf_line_index in genotype_lineindex_svcalls_returned))
                    genotype_lineindex_svcalls_returned[svcall.raw_vcf_line_index]=svcall

            if config.mode=="call_sample":
                if config.snf is not None and result["has_snf"]:
                    snf_candidate_count+=result["snf_candidate_count"]
                    snf_parts.append({"task":task, "snf_filename": result["snf_filename"],"snf_index": result["snf_index"], "snf_total_length": result["snf_total_length"]})
                    contig_coverages[task.contig].append(result["coverage_average_total"])

    #
    # Store avg. coverages for .snf file
    #
    if config.mode=="call_sample" and config.snf is not None:
        for contig in contig_coverages:
            contig_coverages[contig]=sum(contig_coverages[contig])/len(contig_coverages[contig]) if len(contig_coverages[contig])>0 else 0
        config.contig_coverages=contig_coverages

    if total_mapped > 0:
        show_progress()
    print(f"Took {time.time()-analysis_start_time:.2f}s.")
    print("")

    #
    # Output results
    #
    if (config.mode=="call_sample" or config.mode=="combine") and config.vcf is not None and config.sort:
        #TODO: Look into calls placed outside their respective task's range
        reassigned_count=0
        ignored_count=0

        #Pass 1: Copy calls located outside the interval of their task to the task that covers their position
        for task in tasks_list:
            if not task.id in task_id_calls:
                continue
            for call in sorted(task_id_calls[task.id],key=lambda call: call.pos):
                if call.pos < task.start or call.pos >= task.end or call.contig != task.contig:
                    new_target_tasks=[candidate for _,_,candidate in contig_tasks_intervals[call.contig] if call.pos >= candidate.start and call.pos < candidate.end]
                    if len(new_target_tasks)!=1:
                        print(f"WARNING: Unable to assign call at {call.contig}:{call.pos} to unambiguous task. (got {len(new_target_tasks)} intervals). SVCall={call}")
                        if len(new_target_tasks)==0:
                            continue

                    new_task_id=new_target_tasks[0].id
                    if not new_task_id in task_id_calls:
                        task_id_calls[new_task_id]=[]
                    task_id_calls[new_task_id].append(call)
                    reassigned_count+=1

        #Pass 2: Write the calls task by task. The copies left behind in the wrong task by pass 1
        #are skipped here, which is why both counters are expected to match afterwards.
        for task in tasks_list: #Tasks are already sorted by contig/coordinate (from .bam index)
            if not task.id in task_id_calls:
                continue
            for call in sorted(task_id_calls[task.id],key=lambda call: call.pos):
                if call.pos < task.start or call.pos >= task.end or call.contig != task.contig:
                    ignored_count+=1
                else:
                    vcf_out.write_call(call)

        if reassigned_count!=ignored_count:
            print(f"ERROR: {ignored_count} calls ignored, but only {reassigned_count} were reassigned to correct tasks")

    if config.mode=="call_sample" and config.snf is not None:
        #
        # Combine regional temporary .snf files from workers
        #
        main_index={} #Contig -> block -> list of (start offset in combined data, length)
        offset=0
        parts_sorted=sorted(snf_parts,key=lambda p: p["task"].id)
        for part in parts_sorted:
            part_contig=part["task"].contig
            if not part_contig in main_index:
                main_index[part_contig]={}
            for block,(part_block_start,part_block_len) in part["snf_index"].items():
                if not block in main_index[part_contig]:
                    main_index[part_contig][block]=[]
                #Offsets in the part index are relative to the part file, shift them to the combined data
                main_index[part_contig][block].append((part_block_start+offset,part_block_len))
            offset+=part["snf_total_length"]

        #The header is a single JSON line; config entries that are not JSON serializable are replaced by a placeholder
        header={"config": config.__dict__, "index": main_index, "snf_candidate_count": snf_candidate_count}
        header_json=json.dumps(header, default=lambda obj: "<Unstored_Object>")+"\n"
        snf_out.handle.write(header_json.encode())

        for part in parts_sorted:
            with open(part["snf_filename"],"rb") as part_handle:
                part_data=part_handle.read()
            snf_out.handle.write(part_data)
            os.remove(part["snf_filename"])

        snf_out.close()

    if DEV_MONITOR_MEM:
        print(f"[DEV: Total memory usage={dbg_get_total_memory_usage_MB():.2f}MB]") #TODO: Remove (prod)

    if config.mode=="genotype_vcf":
        vcf_out.rewrite_header_genotype(vcf_in.header_str)

        #Keep the order of the input VCF; SVs for which no worker returned a result are left out
        for lineindex in genotype_lineindex_order:
            if lineindex in genotype_lineindex_svcalls_returned:
                vcf_out.rewrite_genotype(genotype_lineindex_svcalls_returned[lineindex])

    if config.vcf is not None:
        vcf_out.close()
        if config.vcf_output_bgz:
            vcf_index_start_time=time.time()
            print(f"Generating index for {config.vcf}...")
            pysam.tabix_index(config.vcf,preset="vcf",force=True)
            print(f"Indexing VCF output took {time.time()-vcf_index_start_time:.2f}s.")

    print(f"Done.")

    if config.mode=="call_sample" and config.snf is not None:
        print(f"Wrote {snf_candidate_count} SV candidates to {config.snf} (for multi-sample calling).")

    if (config.mode=="call_sample" or config.mode=="combine") and config.vcf is not None:
        print(f"Wrote {vcf_out.call_count} called SVs to {config.vcf} {vcf_output_info_str}")

    #
    #Finalize
    #
    for proc in processes:
        proc.pipe_main.send(["finalize",None])

    for proc in processes:
        proc.process.join()


if __name__ == "__main__":
    #Filled by Sniffles2_Main with the worker processes it creates, so they can be shut down
    #here if the run is aborted by a fatal error
    processes=[]

    try:
        Sniffles2_Main(config.from_cmdline(),processes)
    except util.Sniffles2Exit:
        if len(processes):
            #Allow time for child process error messages to propagate
            print("Sniffles2Main: Shutting down workers")
            time.sleep(10)

        #Best effort cleanup: terminate()/join() may fail for workers that were never started or
        #are already gone. Only regular exceptions are ignored, so that e.g. Ctrl+C
        #(KeyboardInterrupt) still interrupts the shutdown.
        for proc in processes:
            try:
                proc.process.terminate()
            except Exception:
                pass

        for proc in processes:
            try:
                proc.process.join()
            except Exception:
                pass
