#!/usr/bin/env python
from __future__ import division, print_function
import argparse
import numpy as np
import igraph as ig
import leidenalg
import sys


#based on https://github.com/theislab/scanpy/blob/8131b05b7a8729eae3d3a5e146292f377dd736f7/scanpy/_utils.py#L159
def get_igraph(sources_idxs_file, targets_idxs_file, weights_file, n_vertices):
    """Build a weighted igraph graph from edge lists stored in NumPy files.

    Args:
        sources_idxs_file: Path (or file object) given to ``np.load``; holds
            the source vertex index of every edge.
        targets_idxs_file: Path (or file object) given to ``np.load``; holds
            the target vertex index of every edge, parallel to the sources.
        weights_file: Path (or file object) given to ``np.load``; holds one
            weight per edge, parallel to the sources/targets.
        n_vertices: Number of vertices to create before the edges are added.

    Returns:
        An ``ig.Graph`` with ``n_vertices`` vertices, one edge per
        (source, target) pair and the loaded weights stored in the
        ``'weight'`` edge attribute.

    Raises:
        ValueError: If the sources, targets and weights arrays do not all
            have the same length.
    """
    sources = np.load(sources_idxs_file)
    targets = np.load(targets_idxs_file)
    weights = np.load(weights_file)

    # zip() below silently truncates to the shorter input, and a weights array
    # of the wrong length would no longer line up one-to-one with the edges,
    # so reject inconsistent inputs up front instead of clustering a graph
    # that does not match what the caller provided.
    n_sources, n_targets, n_weights = len(sources), len(targets), len(weights)
    if not (n_sources == n_targets == n_weights):
        raise ValueError(
            'sources, targets and weights must have the same length; got '
            + str(n_sources) + ', ' + str(n_targets) + ' and '
            + str(n_weights))

    # NOTE(review): directed=None is carried over from the scanpy helper this
    # is based on; presumably igraph treats it as "undirected" - confirm.
    g = ig.Graph(directed=None)
    g.add_vertices(n_vertices)  # create all vertices first, then the edges
    g.add_edges(list(zip(sources, targets)))
    g.es['weight'] = weights
    if g.vcount() != n_vertices:
        print('WARNING: The constructed graph has only '
              + str(g.vcount()) + ' nodes. '
              'Your adjacency matrix contained redundant nodes.')
    return g


if __name__ == "__main__":
    # Command-line entry point: load a weighted graph from .npy edge lists,
    # run leidenalg.find_partition on it and print the partition quality
    # followed by one membership label per line on stdout.
    parser = argparse.ArgumentParser()
    parser.add_argument("--sources_idxs_file", required=True)
    parser.add_argument("--targets_idxs_file", required=True)
    parser.add_argument("--weights_file", required=True)
    parser.add_argument("--n_vertices", type=int, required=True)
    parser.add_argument("--partition_type", required=True)
    parser.add_argument("--n_iterations", type=int, required=True)
    parser.add_argument("--initial_membership_file", default=None,
                        required=False)
    parser.add_argument("--seed", type=int, required=True)

    args = parser.parse_args()

    # Resolve the partition class by attribute lookup instead of eval():
    # eval() on a command-line string would execute arbitrary Python code.
    # Dotted names are walked one attribute at a time so that values such as
    # "Module.ClassName" keep working. This is done before any file is loaded
    # so that a misspelled name fails fast.
    partition_type = leidenalg
    for attr_name in args.partition_type.split("."):
        try:
            partition_type = getattr(partition_type, attr_name)
        except AttributeError:
            parser.error("unknown --partition_type: leidenalg."
                         + args.partition_type)

    the_graph = get_igraph(
        sources_idxs_file=args.sources_idxs_file,
        targets_idxs_file=args.targets_idxs_file,
        weights_file=args.weights_file,
        n_vertices=args.n_vertices)

    # No initial membership file means no initial membership is passed on.
    if args.initial_membership_file is None:
        initial_membership = None
    else:
        initial_membership = np.load(args.initial_membership_file)

    # The edge weights are handed to leidenalg as a float64 array.
    edge_weights = np.array(the_graph.es['weight']).astype(np.float64)

    partition = leidenalg.find_partition(
        graph=the_graph,
        partition_type=partition_type,
        weights=edge_weights,
        n_iterations=args.n_iterations,
        initial_membership=initial_membership,
        seed=args.seed)

    quality = partition.quality()
    # The separator line marks where the result section starts on stdout.
    print("########################")
    print("Quality:", quality)
    print("Membership:")
    print("\n".join(str(x) for x in partition.membership))
    sys.stdout.flush()
