#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import print_function
import numpy as np
from intervul.readpos import VulcanPosMesh, AddResults
from intervul.vulInterVtk import meshToVtkFormat
from intervul.writeVTK import WriteVTK
import os
import argparse


def str2bool(v):
    """Convert a command-line token to a boolean (argparse ``type=`` helper).

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Matching is case-insensitive.

    Returns:
        bool: ``True`` for yes/true/t/y/1, ``False`` for no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: If the string is none of the accepted
            spellings.
    """
    if isinstance(v, bool):
        return v

    # One table for both outcomes; anything missing from it is rejected.
    spellings = {
        'yes': True, 'true': True, 't': True, 'y': True, '1': True,
        'no': False, 'false': False, 'f': False, 'n': False, '0': False,
    }
    outcome = spellings.get(v.lower())
    if outcome is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return outcome


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Convierte un archivo .pos de Vulcan a archivos vtk (muchos)')

# Optional text log. With nargs='?': a bare "-o" selects the const value,
# "-o NAME" selects NAME, and omitting the flag leaves the value as None.
parser.add_argument(
    '-o', '--fileOutput',
    nargs='?',
    action='store',
    const="timeLapse.txt",
    default=None,
    help='Indica el nombre del archivo de salida (default: timeLapse.txt)')

# The three on/off switches differ only in their flags and help text; each
# one is enabled by default and parsed through str2bool.
for _shortFlag, _longFlag, _helpText in (
        ('-p', '--principalStress',
         'Indica si se debe calcular los esfuerzos principales'),
        ('-s', '--principalStrain',
         'Indica si se debe calcular las deformaciones principales'),
        ('-l', '--principalStretch',
         'Indica si se debe calcular los alargamientos principales'),
):
    parser.add_argument(
        _shortFlag, _longFlag, help=_helpText, default=True, type=str2bool)
del _shortFlag, _longFlag, _helpText

# Mandatory positional argument: the input data file.
parser.add_argument('fileIn', help="Archivo de dato")
args = parser.parse_args()

# Unpack the parsed options into the module-level names used further down.
filename, fileout = args.fileIn, args.fileOutput
doCalcPrincipalStress = args.principalStress
doCalcPrincipalStrains = args.principalStrain
doCalcPrincipalStretch = args.principalStretch

# New (derived) results are collected here:
# [AddResults(...), AddResults(...), AddResults(...), ...]
newResults = []

def calcPrincipalStress(result):
    """Compute principal stresses and principal directions for each data row.

    Each stress row is expanded to a symmetric 3x3 tensor and decomposed with
    ``numpy.linalg.eigh``. ``eigh`` returns eigenvalues in ascending order,
    so the last one is reported first (``stressI`` is the largest).

    Args:
        result: Mapping whose ``'stress'`` entry is ``None`` or a 2-D array
            of shape ``(ndata, 4)`` (2D case) or ``(ndata, 6)`` (3D case).

    Returns:
        list: ``[stressI, stressII, stressIII, vecStressI, vecStressII,
        vecStressIII]`` where the scalars have shape ``(ndata,)`` and the
        vectors ``(ndata, 3)``. Six ``None`` entries are returned when there
        is no stress data or the eigen-decomposition fails.

    Raises:
        ValueError: If the stress array has neither 4 nor 6 columns.
    """
    stress = result['stress']
    if stress is None:
        return [None] * 6

    ndata, nvals = stress.shape
    # Each index list below spells out a row-major 3x3 matrix: columns
    # 0, 1, 3 land on the diagonal, column 2 at (0,1)/(1,0), column 4 at
    # (0,2)/(2,0) and column 5 at (1,2)/(2,1).
    # NOTE(review): presumably the component order is xx, yy, xy, zz, xz, yz;
    # confirm against the Vulcan .pos format.
    if nvals == 4:
        # TODO: plane stress is assumed for the 2D case.
        # The appended zero column (index 4) fills both out-of-plane shears.
        stressVoigt = np.hstack((stress, np.zeros((ndata, 1))))
        layout = [0, 2, 4, 2, 1, 4, 4, 4, 3]
    elif nvals == 6:
        stressVoigt = stress
        layout = [0, 2, 4, 2, 1, 5, 4, 5, 3]
    else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(
            "calcPrincipalStress expects 4 or 6 stress components per row, "
            "got {}".format(nvals))

    # reshape() works whatever the memory layout of the indexed copy is.
    stressTensor = stressVoigt[:, layout].reshape(ndata, 3, 3)
    try:
        eigvals, eigvec = np.linalg.eigh(stressTensor)
    except np.linalg.LinAlgError:
        return [None] * 6

    # eigvec[:, :, k] holds, for every row, the eigenvector of eigvals[:, k].
    return [
        eigvals[:, 2], eigvals[:, 1], eigvals[:, 0],
        eigvec[:, :, 2], eigvec[:, :, 1], eigvec[:, :, 0],
    ]

def calcPrincipalStrains(result):
    """Compute principal strains and principal directions for each data row.

    Each 6-component strain row is expanded to a symmetric 3x3 tensor (the
    off-diagonal entries are halved) and decomposed with
    ``numpy.linalg.eigh``. ``eigh`` returns eigenvalues in ascending order,
    so the last one is reported first (``strainI`` is the largest).

    Args:
        result: Mapping whose ``'strain'`` entry is ``None`` or a 2-D array
            of shape ``(ndata, 6)``.

    Returns:
        list: ``[strainI, strainII, strainIII, vecStrainI, vecStrainII,
        vecStrainIII]`` where the scalars have shape ``(ndata,)`` and the
        vectors ``(ndata, 3)``. Six ``None`` entries are returned when there
        is no strain data or the eigen-decomposition fails.

    Raises:
        NotImplementedError: For 4-component (2D) strain data.
        ValueError: For any other column count different from 6.
    """
    strain = result['strain']
    if strain is None:
        return [None] * 6

    ndata, nvals = strain.shape
    # TODO: 2D cases
    if nvals == 4:
        # Still an Exception subclass, so existing handlers keep working.
        raise NotImplementedError("No implementado las deformaciones en 2D")
    if nvals != 6:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(
            "calcPrincipalStrains expects 6 strain components per row, "
            "got {}".format(nvals))

    # Row-major 3x3 layout: columns 0, 1, 3 on the diagonal, column 2 at
    # (0,1)/(1,0), column 4 at (0,2)/(2,0), column 5 at (1,2)/(2,1).
    # Fancy indexing returns a copy, so the caller's array is not modified.
    strainTensor = strain[:, [0, 2, 4, 2, 1, 5, 4, 5, 3]]
    # Halve the six off-diagonal slots.
    # NOTE(review): presumably the stored shears are engineering shears
    # (twice the tensor component); confirm against the Vulcan output.
    strainTensor[:, [1, 2, 3, 5, 6, 7]] /= 2.0
    # reshape() works whatever the memory layout of the indexed copy is.
    strainTensor = strainTensor.reshape(ndata, 3, 3)
    try:
        eigvals, eigvec = np.linalg.eigh(strainTensor)
    except np.linalg.LinAlgError:
        return [None] * 6

    # eigvec[:, :, k] holds, for every row, the eigenvector of eigvals[:, k].
    return [
        eigvals[:, 2], eigvals[:, 1], eigvals[:, 0],
        eigvec[:, :, 2], eigvec[:, :, 1], eigvec[:, :, 0],
    ]


def calcPrincipalStretch(result):
    """Compute principal stretches, their directions and J for each data row.

    From each 6-component strain row the tensor ``C = 2E + I`` is assembled
    and decomposed with ``numpy.linalg.eigh``. The eigenvalues of ``C`` are
    the squared stretches, so their square roots are reported, largest
    first. ``J`` is ``sqrt(det(C))``.

    Args:
        result: Mapping whose ``'strain'`` entry is ``None`` or a 2-D array
            of shape ``(ndata, 6)``.

    Returns:
        list: ``[stretchI, stretchII, stretchIII, vecStretchI, vecStretchII,
        vecStretchIII, J]`` where the scalars have shape ``(ndata,)`` and the
        vectors ``(ndata, 3)``. Seven ``None`` entries are returned when
        there is no strain data or the eigen-decomposition fails, so the
        length always matches the seven results registered for this function.

    Raises:
        NotImplementedError: For 4-component (2D) strain data.
        ValueError: For any other column count different from 6.
    """
    strain = result['strain']
    if strain is None:
        return [None] * 7

    ndata, nvals = strain.shape
    # TODO: plane stress is assumed for the 2D case
    if nvals == 4:
        # Still an Exception subclass, so existing handlers keep working.
        raise NotImplementedError("No implementado las deformaciones en 2D")
    if nvals != 6:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(
            "calcPrincipalStretch expects 6 strain components per row, "
            "got {}".format(nvals))

    # Row-major 3x3 layout: columns 0, 1, 3 on the diagonal, column 2 at
    # (0,1)/(1,0), column 4 at (0,2)/(2,0), column 5 at (1,2)/(2,1).
    # The multiplication yields a new array; the caller's data is untouched.
    stretchTensor = 2 * strain[:, [0, 2, 4, 2, 1, 5, 4, 5, 3]]
    # C = 2E + I: add the identity on the diagonal slots.
    stretchTensor[:, [0, 4, 8]] += 1.0
    # Voigt to tensor: halve the six off-diagonal slots.
    stretchTensor[:, [1, 2, 3, 5, 6, 7]] /= 2.0
    # reshape() works whatever the memory layout of the indexed copy is.
    stretchTensor = stretchTensor.reshape(ndata, 3, 3)

    try:
        eigvals, eigvec = np.linalg.eigh(stretchTensor)
    except np.linalg.LinAlgError:
        # Seven placeholders (this path used to return only six).
        return [None] * 7

    # lambda = sqrt(lambda^2): eigenvalues of C are the squared stretches.
    # A negative eigenvalue yields NaN here, exactly as np.sqrt reports it.
    eigvals = np.sqrt(eigvals)
    J = np.sqrt(np.linalg.det(stretchTensor))

    # eigvec[:, :, k] holds, for every row, the eigenvector of eigvals[:, k].
    return [
        eigvals[:, 2], eigvals[:, 1], eigvals[:, 0],
        eigvec[:, :, 2], eigvec[:, :, 1], eigvec[:, :, 0],
        J,
    ]


# Register one AddResults entry per derived quantity enabled on the command
# line. Each entry gets three parallel lists plus the function that returns
# the values, one list item per returned value.
# NOTE(review): presumably the lists are result kind, location and display
# name; confirm against intervul.readpos.AddResults.
if doCalcPrincipalStress:
    newResults.append(AddResults(
        ['scalars'] * 3 + ['vectors'] * 3,
        ['nodal'] * 6,
        ['StressI', 'StressII', 'StressIII',
         'StressI vec', 'StressII vec', 'StressIII vec'],
        calcPrincipalStress,
    ))

if doCalcPrincipalStrains:
    newResults.append(AddResults(
        ['scalars'] * 3 + ['vectors'] * 3,
        ['nodal'] * 6,
        ['StrainI', 'StrainII', 'StrainIII',
         'StrainI vec', 'StrainII vec', 'StrainIII vec'],
        calcPrincipalStrains,
    ))

if doCalcPrincipalStretch:
    # Seven values here: the extra trailing scalar is J.
    newResults.append(AddResults(
        ['scalars'] * 3 + ['vectors'] * 3 + ['scalars'],
        ['nodal'] * 7,
        ['StretchI', 'StretchII', 'StretchIII',
         'StretchI vec', 'StretchII vec', 'StretchIII vec', 'J'],
        calcPrincipalStretch,
    ))


# Output base name: the input path without its directories and extension.
base = os.path.basename(filename)
name = os.path.splitext(base)[0]
# Open the input file through the project reader, handing it the extra
# results registered in newResults.
# NOTE(review): presumably the reader evaluates those result functions for
# every step it yields; confirm in intervul.readpos.VulcanPosMesh.
data = VulcanPosMesh(filename, VulcanPosMesh.MECHANICAL, newResults=newResults)
origMesh = data.mesh

# Convert the mesh for VTK output. The return value is used as a mapping
# below (its .items() is iterated in IteratorsResults.sets) and its length is
# passed to WriteVTK as the number of sets.
submeshs = meshToVtkFormat(origMesh)
nsets = len(submeshs)

# Optional text log: IteratorsResults.time writes one line per step to it.
if fileout:
    myoutput = open(fileout, 'w')


class IteratorsResults:
    """Expose the results in ``data`` as the two generators given to WriteVTK.

    ``time`` walks the steps and remembers the current one on the instance;
    ``sets`` then reads that remembered step, so it is only meaningful while
    ``time`` is being iterated.
    """

    def __init__(self, data):
        # Iterable of (step index, step results) pairs; see time().
        self.data = data

    def time(self):
        """Yield the ``'TimeValue'`` of every step in ``self.data``.

        For each step a summary line is printed and, when the module-level
        ``fileout`` is set, also written to ``myoutput``. The current step
        is stored in ``self.istep`` / ``self.allResult``.
        """
        template = "Interval:{:3d} Step:{:5d} Iteration:{:4d} Time:{:10.5f}"
        for entry in self.data:
            self.istep, self.allResult = entry
            current = self.allResult
            text = template.format(current['itime'], current['istep'],
                                   current['iiter'], current['TimeValue'])
            print(text)
            if fileout:
                myoutput.write(text + "\n")
            yield self.allResult['TimeValue']

    def sets(self):
        """Yield one dict of VTK inputs per sub-mesh for the current step.

        Each dict carries the sub-mesh geometry plus the cell, point and
        field data of ``self.allResult`` restricted to that sub-mesh.
        """
        for _, mesh in submeshs.items():
            splitResult = self.allResult.updateBySplitMesh(mesh)

            def collect(which):
                return splitResult.getResults(transformTo3D=True, which=which)

            # Same call order as before: nodal, then cell, then field.
            pointData = collect("nodal")
            cellData = collect("cell")
            fieldData = collect("field")

            # Field entries that are not exactly ndarray are wrapped in one.
            for key, value in fieldData.items():
                if type(value) is np.ndarray:
                    continue
                fieldData[key] = np.array(value)
            # The title entry is removed from the field data before output.
            del fieldData['titleResult']

            payload = {
                'x': mesh.x,
                'y': mesh.y,
                'z': mesh.z,
                'connectivity': mesh.connectivity,
                'offsets': mesh.offset,
                'cell_types': mesh.ctype,
                'cellData': cellData,
                'pointData': pointData,
                'fieldData': fieldData,
            }
            yield payload


# Drive the export: WriteVTK receives the two generator factories and the
# number of sub-meshes.
myiter = IteratorsResults(data)
try:
    WriteVTK(name, myiter.time, myiter.sets, nsets)
finally:
    # The optional time-step log is opened only when fileout is set; close
    # it (which also flushes it) even if the export fails part-way.
    if fileout:
        myoutput.close()
