#!/usr/bin/env python2
# -*- coding: utf-8 -*-

#
#    LinOTP - the open source solution for two factor authentication
#    Copyright (C) 2010 - 2019 KeyIdentity GmbH
#
#    This file is part of LinOTP server.
#
#    This program is free software: you can redistribute it and/or
#    modify it under the terms of the GNU Affero General Public
#    License, version 3, as published by the Free Software Foundation.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU Affero General Public License for more details.
#
#    You should have received a copy of the
#               GNU Affero General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#
#    E-mail: linotp@keyidentity.com
#    Contact: www.linotp.org
#    Support: www.keyidentity.com
#

'''
This script can be used to convert special Gemalto PSKC files to a simple
HOTP CSV file, that can be imported into LinOTP.

parameter
    --file <Gemalto-PSKC.xml>

The script expects the values to be PBE-AES128-CBC encrypted.
'''
from getopt import getopt, GetoptError
import sys
import re
import getpass
import base64
import linotp.lib.pbkdf2 as pbkdf2
from linotp.lib.ImportOTP.PSKC import aes_decrypt
import xml.etree.cElementTree as etree
import binascii

def usage():
    """Print the module docstring, which serves as the help text."""
    # Single-argument parenthesised form: same output as the print statement.
    print(__doc__)

def main():
    """Convert a Gemalto PSKC file into CSV lines printed on stdout.

    The file name is taken from the ``--file``/``-f`` command line option.
    The container must declare the ``PBE-AES128-CBC`` encryption algorithm;
    the password is requested interactively and a 16 byte key is derived
    from it with PBKDF2 using the salt and iteration count found in the file.
    For every ``Device`` element one line
    ``serial, seed(hex), algorithm, truncation`` is printed.

    Exit codes (via ``sys.exit``):
        0 -- help was requested
        1 -- bad command line syntax or missing file name
        2 -- unsupported encryption algorithm, or the encryption
             parameters are missing from the file
    """
    try:
        opts, _args = getopt(sys.argv[1:], "hf:", ["help", "file="])
    except GetoptError:
        print("There is an error in your parameter syntax:")
        usage()
        sys.exit(1)

    filename = ""
    for opt, arg in opts:
        if opt in ('--help', '-h'):
            usage()
            sys.exit(0)
        elif opt in ('--file', '-f'):
            filename = arg

    if not filename:
        print("missing filename!")
        usage()
        sys.exit(1)

    # The context manager closes the file even if read() raises.
    with open(filename, 'r') as xml_file:
        xml = xml_file.read()

    root = etree.fromstring(xml)

    # ElementTree reports namespaced tags as "{uri}name"; keep the "{uri}"
    # prefix so that all child lookups below can be qualified with it.
    namespace = ""
    match = re.match("^({.*?})SecretContainer$", root.tag)
    if match:
        namespace = match.group(1)
        print("Found namespace %s" % namespace)

    # Determine all encryption information
    encryption_method = root.find(namespace + "EncryptionMethod")
    if encryption_method is None:
        # Without this guard the next line fails with an opaque
        # AttributeError on None.
        print("No EncryptionMethod element found in %s" % filename)
        sys.exit(2)

    enc_algorithm = encryption_method.get("algorithm")
    if enc_algorithm != "PBE-AES128-CBC":
        print("We only support PBE-AES128-CBC. We found: %s" % enc_algorithm)
        sys.exit(2)

    enc_params = {}
    for param_name in ("PBESalt", "PBEIterationCount", "IV"):
        param_node = encryption_method.find(namespace + param_name)
        if param_node is None:
            print("Missing %s element in EncryptionMethod" % param_name)
            sys.exit(2)
        enc_params[param_name] = param_node.text

    salt = enc_params["PBESalt"]
    iteration = enc_params["PBEIterationCount"]
    iv = enc_params["IV"]
    print("Encryption parameters:")
    print(" salt: %s" % salt)
    print(" iterations: %s" % iteration)
    print(" iv: %s" % iv)

    password = getpass.getpass("Please enter password for XML file: ")

    # Calculate the encryption key: 16 bytes, i.e. an AES-128 key.
    key_len = 16
    encryption_key_bin = pbkdf2.pbkdf2(password.encode('ascii'),
                                       base64.b64decode(salt),
                                       key_len, int(iteration))
    encryption_key_hex = binascii.hexlify(encryption_key_bin)

    print("We derived the encryption key of %s bytes"
          % len(encryption_key_bin))

    # Get all tokens: serial -> (seed, algorithm, truncation)
    tokens = {}
    for element in root:
        if element.tag != namespace + "Device":
            continue

        # get serial number
        device_id = element.find(namespace + "DeviceId")
        serial = device_id.find(namespace + "SerialNo").text

        # get parameters
        secret = element.find(namespace + "Secret")
        secret_id = secret.get("SecretId")
        algorithm = secret.get("SecretAlgorithm")

        # A mismatch is only reported; the token is still processed.
        if serial != secret_id:
            print("Error: serial %s does not match SecretId %s"
                  % (serial, secret_id))

        usage_node = secret.find(namespace + "Usage")
        algo_id = usage_node.find(namespace + "AlgorithmIdentifier")
        cryptofunction = algo_id.find(namespace + "CryptoFunction").text
        # An unsupported function is only reported; the token is still
        # written to the output.
        if cryptofunction != "HMAC-SHA1":
            print("We do not support %s for serial %s"
                  % (cryptofunction, serial))
        truncation = algo_id.find(namespace + "Truncation").text

        # Get seed; it stays empty when no Data element named SECRET exists.
        seed = ""
        for data in secret:
            if data.tag != namespace + "Data":
                continue
            if data.get("Name") != "SECRET":
                continue

            value_b64 = data.find(namespace + "Value").text
            # TODO: check the digest (ValueDigest / DigestMethod)...
            # aes_decrypt is handed base64(IV + ciphertext) plus the hex
            # encoded key -- presumably it expects the IV prepended to
            # the data; confirm against linotp.lib.ImportOTP.PSKC.
            encrypted_b64 = base64.b64encode(
                base64.b64decode(iv) + base64.b64decode(value_b64))
            print("decrypting %s" % encrypted_b64)
            seed = binascii.hexlify(
                aes_decrypt(encrypted_b64, encryption_key_hex))

        tokens[serial] = (seed, algorithm, truncation)

    for serial, (seed, algorithm, truncation) in tokens.items():
        print("%s, %s, %s, %s" % (serial, seed, algorithm, truncation))


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
