#!/usr/bin/python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import print_function
import numpy as np
from intervul.readpos import VulcanPosMesh, AddResults
from intervul.vulInterVtk import meshToVtkFormat
from intervul.writeVTK import WriteVTK
import os
import argparse


def str2bool(v):
    """Interpret a command-line token as a boolean (argparse ``type=`` hook).

    Real ``bool`` values pass through untouched.  Any other value is
    lower-cased and matched against the accepted spellings:
    yes/true/t/y/1 give True, no/false/f/n/0 give False.

    Raises:
        argparse.ArgumentTypeError: if the text matches neither group.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    spellingTable = (
        (('yes', 'true', 't', 'y', '1'), True),
        (('no', 'false', 'f', 'n', '0'), False),
    )
    for spellings, meaning in spellingTable:
        if token in spellings:
            return meaning
    raise argparse.ArgumentTypeError('Boolean value expected.')


# Command-line interface: one positional input file plus optional switches.
parser = argparse.ArgumentParser(
    description='Convierte un archivo .pos de Vulcan a archivos vtk (muchos)')

# "-o" on its own selects the const file name, "-o NAME" selects NAME, and
# leaving the option out keeps the value None.
parser.add_argument(
    '-o', '--fileOutput',
    action='store',
    nargs='?',
    const="timeLapse.txt",
    default=None,
    help='Indica el nombre del archivo de salida (default: timeLapse.txt)')

# Boolean switches, registered in this order so the generated help keeps its
# layout.  Each row is (short flag, long flag, default, help text).
_boolSwitches = (
    ('-p', '--principalStress', True,
     'Indica si se debe calcular los esfuerzos principales'),
    ('-s', '--principalStrain', True,
     'Indica si se debe calcular las deformaciones principales'),
    ('-l', '--principalStretch', True,
     'Indica si se debe calcular los alargamientos principales'),
    ('-pM', '--principalStressMag', False,
     'Indica si se debe calcular los esfuerzos principales (ordenados por magnitud)'),
    ('-sM', '--principalStrainMag', False,
     'Indica si se debe calcular las deformaciones principales (ordenados por magnitud)'),
    ('-lM', '--principalStretchMag', False,
     'Indica si se debe calcular los alargamientos principales (ordenados por magnitud)'),
)
for _shortFlag, _longFlag, _default, _helpText in _boolSwitches:
    parser.add_argument(
        _shortFlag, _longFlag,
        type=str2bool,
        default=_default,
        help=_helpText)

parser.add_argument('fileIn', help="Archivo de dato")
args = parser.parse_args()

# Module-level names read by the rest of the script.
filename, fileout = args.fileIn, args.fileOutput
doCalcPrincipalStress = args.principalStress
doCalcPrincipalStrains = args.principalStrain
doCalcPrincipalStretch = args.principalStretch
doCalcPrincipalStressMag = args.principalStressMag
doCalcPrincipalStrainsMag = args.principalStrainMag
doCalcPrincipalStretchMag = args.principalStretchMag

# Extra results to compute are collected here:
# [AddResults(...), AddResults(...), ...]
newResults = []

def calcPrincipalStress(result):
    """Compute principal stresses and their directions.

    Args:
        result: mapping whose ``'stress'`` entry is either None or a 2-D
            array with one row per data point and 4 or 6 columns.  With
            columns c0..c5 the symmetric matrix built per row is
            ``[[c0, c2, c4], [c2, c1, c5], [c4, c5, c3]]``; with 4 columns
            the c4/c5 slots are filled with zeros.

    Returns:
        ``[stressI, stressII, stressIII, vecStressI, vecStressII,
        vecStressIII]``.  The values are the eigenvalues in descending
        algebraic order (I is the largest), each of shape ``(ndata,)``; the
        vectors are the matching eigenvectors, each of shape ``(ndata, 3)``.
        Six ``None`` entries are returned when there is no stress data or
        the eigen-decomposition does not converge.

    Raises:
        ValueError: if the stress array has neither 4 nor 6 columns.
    """
    stress = result['stress']
    if stress is None:
        return [None, None, None, None, None, None]

    ndata, nvals = stress.shape
    if nvals == 4:
        # TODO: plane stress is assumed for the 2D case.  A zero column is
        # appended so that index 4 can fill every out-of-plane shear slot.
        stressVoigt = np.hstack((stress, np.zeros((ndata, 1))))
        stressTensor = stressVoigt[:, [0, 2, 4, 2, 1, 4, 4, 4, 3]]
    elif nvals == 6:
        stressTensor = stress[:, [0, 2, 4, 2, 1, 5, 4, 5, 3]]
    else:
        # Previously this fell through and failed on an unbound local.
        raise ValueError(
            "Unsupported number of stress components: {} "
            "(expected 4 or 6)".format(nvals))
    stressTensor = stressTensor.reshape(ndata, 3, 3)

    try:
        # eigh returns ascending eigenvalues; eigenvectors are the columns.
        eigvals, eigvec = np.linalg.eigh(stressTensor)
    except np.linalg.LinAlgError:
        return [None, None, None, None, None, None]

    stressI = eigvals[:, 2]
    stressII = eigvals[:, 1]
    stressIII = eigvals[:, 0]
    vecStressI = eigvec[:, :, 2]
    vecStressII = eigvec[:, :, 1]
    vecStressIII = eigvec[:, :, 0]
    return [
        stressI, stressII, stressIII, vecStressI, vecStressII, vecStressIII
    ]

def calcPrincipalStrains(result):
    """Compute principal strains and their directions (3D data only).

    Args:
        result: mapping whose ``'strain'`` entry is either None or a 2-D
            array with one row per data point and 6 columns.  With columns
            c0..c5 the symmetric matrix built per row is
            ``[[c0, c2/2, c4/2], [c2/2, c1, c5/2], [c4/2, c5/2, c3]]``,
            i.e. the stored shear columns are halved.

    Returns:
        ``[strainI, strainII, strainIII, vecStrainI, vecStrainII,
        vecStrainIII]``.  The values are the eigenvalues in descending
        algebraic order, each of shape ``(ndata,)``; the vectors are the
        matching eigenvectors, each of shape ``(ndata, 3)``.  Six ``None``
        entries are returned when there is no strain data or the
        eigen-decomposition does not converge.

    Raises:
        Exception: for 4-column (2D) data, which is not implemented.
        ValueError: for any other column count different from 6.
    """
    strain = result['strain']
    if strain is None:
        return [None, None, None, None, None, None]

    ndata, nvals = strain.shape
    # TODO: 2D cases.
    if nvals == 4:
        raise Exception("No implementado las deformaciones en 2D")
    if nvals != 6:
        # Previously this fell through and failed on an unbound local.
        raise ValueError(
            "Unsupported number of strain components: {} "
            "(expected 6)".format(nvals))

    # Fancy indexing copies, so the in-place halving below never touches
    # the caller's array.
    strainTensor = strain[:, [0, 2, 4, 2, 1, 5, 4, 5, 3]]
    strainTensor[:, [1, 2, 3, 5, 6, 7]] /= 2.0
    strainTensor = strainTensor.reshape(ndata, 3, 3)

    try:
        # eigh returns ascending eigenvalues; eigenvectors are the columns.
        eigvals, eigvec = np.linalg.eigh(strainTensor)
    except np.linalg.LinAlgError:
        return [None, None, None, None, None, None]

    strainI = eigvals[:, 2]
    strainII = eigvals[:, 1]
    strainIII = eigvals[:, 0]
    vecStrainI = eigvec[:, :, 2]
    vecStrainII = eigvec[:, :, 1]
    vecStrainIII = eigvec[:, :, 0]
    return [
        strainI, strainII, strainIII, vecStrainI, vecStrainII, vecStrainIII
    ]


def calcPrincipalStretch(result):
    """Compute principal stretches, their directions and J (3D data only).

    Per data point the matrix ``C = 2E + I`` is assembled from the strain
    columns, where E is the symmetric matrix
    ``[[c0, c2/2, c4/2], [c2/2, c1, c5/2], [c4/2, c5/2, c3]]`` (the stored
    shear columns are halved).  The stretches are the square roots of the
    eigenvalues of C and ``J = sqrt(det(C))``.

    Args:
        result: mapping whose ``'strain'`` entry is either None or a 2-D
            array with one row per data point and 6 columns.

    Returns:
        ``[stretchI, stretchII, stretchIII, vecStretchI, vecStretchII,
        vecStretchIII, J]``.  The stretches are in descending order, each of
        shape ``(ndata,)``; the vectors are the matching eigenvectors, each
        of shape ``(ndata, 3)``.  Seven ``None`` entries are returned when
        there is no strain data or the eigen-decomposition does not
        converge.

    Raises:
        Exception: for 4-column (2D) data, which is not implemented.
        ValueError: for any other column count different from 6.
    """
    strain = result['strain']
    if strain is None:
        return [None, None, None, None, None, None, None]

    # The shape must come from the strain array itself; it used to be read
    # from result['stress'], which may be missing or shaped differently.
    ndata, nvals = strain.shape
    if nvals == 4:
        raise Exception("No implementado las deformaciones en 2D")
    if nvals != 6:
        # Previously this fell through and failed on an unbound local.
        raise ValueError(
            "Unsupported number of strain components: {} "
            "(expected 6)".format(nvals))

    # C = 2E + I.  The multiplication creates a new array, so the in-place
    # updates below never touch the caller's data.
    stretchTensor = 2 * strain[:, [0, 2, 4, 2, 1, 5, 4, 5, 3]]
    stretchTensor[:, [0, 4, 8]] += 1.0
    # Voigt to tensor: halve the off-diagonal entries.
    stretchTensor[:, [1, 2, 3, 5, 6, 7]] /= 2.0
    stretchTensor = stretchTensor.reshape(ndata, 3, 3)

    try:
        # eigh returns ascending eigenvalues; eigenvectors are the columns.
        eigvals, eigvec = np.linalg.eigh(stretchTensor)
    except np.linalg.LinAlgError:
        # Seven entries, matching the normal return (J included).
        return [None, None, None, None, None, None, None]

    # The eigenvalues of C are the squared stretches.
    eigvals = np.sqrt(eigvals)
    stretchI = eigvals[:, 2]
    stretchII = eigvals[:, 1]
    stretchIII = eigvals[:, 0]
    vecStretchI = eigvec[:, :, 2]
    vecStretchII = eigvec[:, :, 1]
    vecStretchIII = eigvec[:, :, 0]
    J = np.sqrt(np.linalg.det(stretchTensor))
    return [
        stretchI, stretchII, stretchIII,
        vecStretchI, vecStretchII, vecStretchIII, J
    ]


def calcMagnitudeStress(result):
    """Re-order the principal stresses by absolute value.

    The entries ``'StressI'``, ``'StressII'``, ``'StressIII'`` and the
    matching ``'... vec'`` entries must already be present in ``result``.

    Args:
        result: mapping providing ``'stress'`` (2-D array, used only for its
            column count), the three principal values (one value per data
            point) and the three principal directions (assumed to be
            ``(ndata, ncomp)`` arrays).

    Returns:
        ``[magn1, magn2, magn3, magnV1, magnV2, magnV3]`` where ``magn1`` is
        the signed value with the largest absolute value, ``magn3`` the one
        with the smallest, and ``magnV*`` are the directions re-ordered the
        same way.

    Raises:
        Exception: for 4-column (2D) data, which is not implemented.
        ValueError: for any other column count different from 6.
    """
    # The component count is read from the stress array; it used to be
    # taken from result['strain'], which need not exist for stress output.
    _, nvals = result['stress'].shape
    if nvals == 4:
        raise Exception("No implementado en 2D")
    if nvals != 6:
        # Previously this silently returned None.
        raise ValueError(
            "Unsupported number of stress components: {} "
            "(expected 6)".format(nvals))

    stress = np.stack((result['StressIII'],
                       result['StressII'],
                       result['StressI']), axis=1)
    # stressVec[n, :, k] is the direction paired with stress[n, k].
    stressVec = np.stack((result['StressIII vec'],
                          result['StressII vec'],
                          result['StressI vec']), axis=-1)

    # Per-point permutation that sorts the values by ascending |value|.
    order = np.argsort(np.abs(stress), axis=-1)
    stress = np.take_along_axis(stress, order, axis=-1)
    # Apply the same permutation to every vector component at once; the
    # inserted axis broadcasts over the component axis.
    stressVec = np.take_along_axis(
        stressVec, order[:, np.newaxis, :], axis=-1)

    magn1 = stress[:, 2]
    magn2 = stress[:, 1]
    magn3 = stress[:, 0]
    magnV1 = stressVec[:, :, 2]
    magnV2 = stressVec[:, :, 1]
    magnV3 = stressVec[:, :, 0]
    return [magn1, magn2, magn3, magnV1, magnV2, magnV3]


def calcMagnitudeStrain(result):
    """Re-order the principal strains by absolute value.

    The entries ``'StrainI'``, ``'StrainII'``, ``'StrainIII'`` and the
    matching ``'... vec'`` entries must already be present in ``result``.

    Args:
        result: mapping providing ``'strain'`` (2-D array, used only for its
            column count), the three principal values (one value per data
            point) and the three principal directions (assumed to be
            ``(ndata, ncomp)`` arrays).

    Returns:
        ``[magn1, magn2, magn3, magnV1, magnV2, magnV3]`` where ``magn1`` is
        the signed value with the largest absolute value, ``magn3`` the one
        with the smallest, and ``magnV*`` are the directions re-ordered the
        same way.

    Raises:
        Exception: for 4-column (2D) data, which is not implemented.
        ValueError: for any other column count different from 6.
    """
    _, nvals = result['strain'].shape
    if nvals == 4:
        raise Exception("No implementado en 2D")
    if nvals != 6:
        # Previously this silently returned None.
        raise ValueError(
            "Unsupported number of strain components: {} "
            "(expected 6)".format(nvals))

    strain = np.stack((result['StrainIII'],
                       result['StrainII'],
                       result['StrainI']), axis=1)
    # strainVec[n, :, k] is the direction paired with strain[n, k].
    strainVec = np.stack((result['StrainIII vec'],
                          result['StrainII vec'],
                          result['StrainI vec']), axis=-1)

    # Per-point permutation that sorts the values by ascending |value|.
    order = np.argsort(np.abs(strain), axis=-1)
    strain = np.take_along_axis(strain, order, axis=-1)
    # Apply the same permutation to every vector component at once; the
    # inserted axis broadcasts over the component axis.  This also works
    # for zero data points, where indexing row 0 used to fail.
    strainVec = np.take_along_axis(
        strainVec, order[:, np.newaxis, :], axis=-1)

    magn1 = strain[:, 2]
    magn2 = strain[:, 1]
    magn3 = strain[:, 0]
    magnV1 = strainVec[:, :, 2]
    magnV2 = strainVec[:, :, 1]
    magnV3 = strainVec[:, :, 0]
    return [magn1, magn2, magn3, magnV1, magnV2, magnV3]


def calcMagnitudeStretch(result):
    """Re-order the principal stretches by absolute value.

    The entries ``'StretchI'``, ``'StretchII'``, ``'StretchIII'`` and the
    matching ``'... vec'`` entries must already be present in ``result``.

    Args:
        result: mapping providing ``'strain'`` (2-D array, used only for its
            column count), the three principal values (one value per data
            point) and the three principal directions (assumed to be
            ``(ndata, ncomp)`` arrays).

    Returns:
        ``[magn1, magn2, magn3, magnV1, magnV2, magnV3]`` where ``magn1`` is
        the signed value with the largest absolute value, ``magn3`` the one
        with the smallest, and ``magnV*`` are the directions re-ordered the
        same way.

    Raises:
        Exception: for 4-column (2D) data, which is not implemented.
        ValueError: for any other column count different from 6.
    """
    _, nvals = result['strain'].shape
    if nvals == 4:
        raise Exception("No implementado en 2D")
    if nvals != 6:
        # Previously this silently returned None.
        raise ValueError(
            "Unsupported number of strain components: {} "
            "(expected 6)".format(nvals))

    stretch = np.stack((result['StretchIII'],
                        result['StretchII'],
                        result['StretchI']), axis=1)
    # stretchVec[n, :, k] is the direction paired with stretch[n, k].
    stretchVec = np.stack((result['StretchIII vec'],
                           result['StretchII vec'],
                           result['StretchI vec']), axis=-1)

    # Per-point permutation that sorts the values by ascending |value|.
    order = np.argsort(np.abs(stretch), axis=-1)
    stretch = np.take_along_axis(stretch, order, axis=-1)
    # Apply the same permutation to every vector component at once; the
    # inserted axis broadcasts over the component axis.  This also works
    # for zero data points, where indexing row 0 used to fail.
    stretchVec = np.take_along_axis(
        stretchVec, order[:, np.newaxis, :], axis=-1)

    magn1 = stretch[:, 2]
    magn2 = stretch[:, 1]
    magn3 = stretch[:, 0]
    magnV1 = stretchVec[:, :, 2]
    magnV2 = stretchVec[:, :, 1]
    magnV3 = stretchVec[:, :, 0]
    return [magn1, magn2, magn3, magnV1, magnV2, magnV3]

# Result kinds shared by every entry: three scalar values followed by three
# vector directions.  The stretch entry adds one more scalar for 'J'.
_principalKinds = ('scalars', 'scalars', 'scalars',
                   'vectors', 'vectors', 'vectors')

# One row per optional derived result, in registration order:
# (enabled flag, result kinds, result names, function computing them).
_derivedResults = (
    (doCalcPrincipalStress, _principalKinds,
     ['StressI', 'StressII', 'StressIII',
      'StressI vec', 'StressII vec', 'StressIII vec'],
     calcPrincipalStress),
    (doCalcPrincipalStrains, _principalKinds,
     ['StrainI', 'StrainII', 'StrainIII',
      'StrainI vec', 'StrainII vec', 'StrainIII vec'],
     calcPrincipalStrains),
    (doCalcPrincipalStretch, _principalKinds + ('scalars',),
     ['StretchI', 'StretchII', 'StretchIII',
      'StretchI vec', 'StretchII vec', 'StretchIII vec', 'J'],
     calcPrincipalStretch),
    (doCalcPrincipalStressMag, _principalKinds,
     ['Stress1Mag', 'Stress2Mag', 'Stress3Mag',
      'StressIMag Vec', 'StressIIMag Vec', 'StressIIIMag Vec'],
     calcMagnitudeStress),
    (doCalcPrincipalStrainsMag, _principalKinds,
     ['Strain1Mag', 'Strain2Mag', 'Strain3Mag',
      'StrainIMag Vec', 'StrainIIMag Vec', 'StrainIIIMag Vec'],
     calcMagnitudeStrain),
    (doCalcPrincipalStretchMag, _principalKinds,
     ['Stretch1Mag', 'Stretch2Mag', 'Stretch3Mag',
      'StretchIMag Vec', 'StretchIIMag Vec', 'StretchIIIMag Vec'],
     calcMagnitudeStretch),
)

for _enabled, _kinds, _names, _callback in _derivedResults:
    if not _enabled:
        continue
    # Every derived result is nodal; fresh lists are built for each entry.
    newResults.append(
        AddResults(list(_kinds), ['nodal'] * len(_kinds), _names, _callback))

# Output base name: the input file name without directory or extension.
base = os.path.basename(filename)
name, _inputExt = os.path.splitext(base)

# Read the input file, handing over the derived results registered above.
data = VulcanPosMesh(
    filename,
    VulcanPosMesh.MECHANICAL,
    newResults=newResults)
origMesh = data.mesh

# One entry per sub-mesh; the count is passed on to the writer.
submeshs = meshToVtkFormat(origMesh)
nsets = len(submeshs)

# The step log is optional: myoutput only exists when an output file name
# was requested on the command line.
if fileout:
    myoutput = open(fileout, 'w')


class IteratorsResults:
    """Pair of generators over the parsed results, handed to WriteVTK.

    ``time`` walks the entries of ``data`` and remembers the current one on
    the instance; ``sets`` yields one dictionary per sub-mesh for that
    remembered entry.  ``sets`` therefore only works once ``time`` has
    produced a step.  Both methods read the module-level names ``fileout``,
    ``myoutput`` and ``submeshs``.
    """

    def __init__(self, data):
        """Store the iterable of ``(step, results)`` pairs to walk over."""
        self.data = data

    def time(self):
        """Yield the 'TimeValue' of each step, logging a summary line.

        The line is printed and, when an output file was requested, also
        written to it.
        """
        for step, stepResult in self.data:
            # sets() reads these attributes, so publish them before yielding.
            self.istep = step
            self.allResult = stepResult

            pieces = (
                "Interval:", "{:3d}".format(stepResult['itime']),
                " Step:", "{:5d}".format(stepResult['istep']),
                " Iteration:", "{:4d}".format(stepResult['iiter']),
                " Time:", "{:10.5f}".format(stepResult['TimeValue']),
            )
            text = "".join(pieces)
            print(text)
            if fileout:
                myoutput.write(text + "\n")
            yield stepResult['TimeValue']

    def sets(self):
        """Yield, per sub-mesh, the geometry and results of the current step.

        Each yielded dictionary carries the mesh coordinates and topology
        plus the cell, nodal and field results obtained for that sub-mesh.
        """
        for _setId, mesh in submeshs.items():
            partial = self.allResult.updateBySplitMesh(mesh)
            # Same query order as always: nodal, then cell, then field.
            pointData = partial.getResults(transformTo3D=True, which="nodal")
            cellData = partial.getResults(transformTo3D=True, which="cell")
            fieldData = partial.getResults(transformTo3D=True, which="field")

            # Field entries that are not plain ndarrays are converted.
            for key, value in fieldData.items():
                if type(value) is np.ndarray:
                    continue
                fieldData[key] = np.array(value)
            # The title entry is dropped from the field data.
            del fieldData['titleResult']

            piece = {
                'x': mesh.x,
                'y': mesh.y,
                'z': mesh.z,
                'connectivity': mesh.connectivity,
                'offsets': mesh.offset,
                'cell_types': mesh.ctype,
                'cellData': cellData,
                'pointData': pointData,
                'fieldData': fieldData,
            }
            yield piece


# Drive the conversion: WriteVTK receives the two generator methods and the
# number of sub-meshes.
myiter = IteratorsResults(data)
try:
    WriteVTK(name, myiter.time, myiter.sets, nsets)
finally:
    # The optional step log is opened earlier in the script and was never
    # closed; close it here even if the conversion fails.
    if fileout:
        myoutput.close()
