#!/usr/bin/env python3

import sys
import os
import time
import json
import docker
import threading
import argparse
from argparse import ArgumentTypeError as err
from base64 import b64decode
import watchdog.events
import watchdog.observers
import time
from pathlib import Path
import logging

# Docker label whose value is a comma separated list of domains; a container
# carrying it is restarted when a certificate for one of those domains is exported.
DOCKER_LABLE = "com.github.ravensorb.traefik-certificate-exporter.domain-restart"

###########################################################################################################
###########################################################################################################
# Default settings used by the __main__ block; command line options (and an
# optional JSON config file) take precedence over these values.
settings = dict(
    watchPath="./",
    fileSpec="*.json",
    outputPath="./certs",
    traefikResolverId=None,
    flat=False,
    dryRun=False,
    restartContainers=False,
    domains=dict(
        include=[],
        exclude=[],
    ),
)
###########################################################################################################

class AcmeCertificateExporter:
    """Export certificates stored in a Traefik ACME JSON file as PEM files.

    ``settings`` must provide ``traefikResolverId``, ``flat``, ``dryRun``,
    ``outputPath`` and ``domains`` (a dict with ``include``/``exclude`` lists).
    """

    def __init__(self, settings : dict):

        self.__settings = settings

    def extractCertificates(self, sourceFile : str):
        """Extract every (filtered) certificate from ``sourceFile``.

        When ``traefikResolverId`` is set, the resolver's data is read from that
        top-level key and lowercase field names are used; otherwise the top level
        is read with uppercase field names. The ACME version (1 or 2) is derived
        from the account registration URI.

        Unless ``dryRun`` is set, files are written below ``outputPath``:
        ``<name>.key`` / ``<name>.crt`` / ``<name>.chain.pem`` in flat mode,
        otherwise ``<name>/privkey.pem``, ``cert.pem``, ``chain.pem`` and
        ``fullchain.pem``.

        :param sourceFile: path of the ACME JSON file to read
        :return: list of the main domain names that were extracted
        :raises KeyError: if the file does not have the expected structure
        :raises OSError, json.JSONDecodeError: if the file cannot be read/parsed
        """
        with open(sourceFile, encoding='utf-8') as f:
            data = json.load(f)

        keys = "uppercase"
        resolverId = self.__settings["traefikResolverId"]
        if resolverId:
            data = data[resolverId]
            keys = "lowercase"

        # Determine ACME version
        acme_version = 2 if 'acme-v02' in data['Account']['Registration']['uri'] else 1

        # Find certificates; the list may be null in the file, which means "none"
        if acme_version == 1:
            certs = data['DomainsCertificate']['Certs'] or []
        else:
            certs = data['Certificates'] or []

        include = self.__settings["domains"]["include"]
        exclude = self.__settings["domains"]["exclude"]
        names = []

        for c in certs:
            name, privatekey, fullchain, sans = self._readEntry(c, acme_version, keys)

            if (include and name not in include) or (exclude and name in exclude):
                continue

            # Decode private key, certificate and chain
            privatekey = b64decode(privatekey).decode('utf-8')
            fullchain = b64decode(fullchain).decode('utf-8')
            cert, chain = self._splitChain(fullchain)

            # In dry-run mode nothing is written to disk
            if not self.__settings["dryRun"]:
                self._writeFiles(str(name), privatekey, cert, chain, fullchain)

            logging.info("Extracted certificate for: {} ({})".format(name, ', '.join(sans) if sans else ''))

            names.append(name)

        return names

    @staticmethod
    def _readEntry(c : dict, acme_version : int, keys : str):
        """Return (main domain, base64 key, base64 full chain, SANs) of one certificate entry."""
        if acme_version == 1:
            return (c['Certificate']['Domain'], c['Certificate']['PrivateKey'],
                    c['Certificate']['Certificate'], c['Domains']['SANs'])
        if keys == "uppercase":
            return c['Domain']['Main'], c['Key'], c['Certificate'], c['Domain'].get('SANs')
        return c['domain']['main'], c['key'], c['certificate'], c['domain'].get('sans') or []

    @staticmethod
    def _splitChain(fullchain : str):
        """Split a PEM bundle into (leaf certificate, remaining chain).

        The leaf is everything before the second ``BEGIN CERTIFICATE`` marker.
        When the bundle holds a single certificate, the whole bundle is the leaf
        and the chain is empty (a naive slice with ``find() == -1`` would chop
        off the last character instead).
        """
        start = fullchain.find('-----BEGIN CERTIFICATE-----', 1)
        if start == -1:
            return fullchain, ''
        return fullchain[:start], fullchain[start:]

    def _writeFiles(self, name : str, privatekey : str, cert : str, chain : str, fullchain : str):
        """Write the PEM material for one domain below ``outputPath``."""
        directory = Path(self.__settings["outputPath"])

        if self.__settings["flat"]:
            # All domains share one folder; the domain name prefixes each file
            files = {
                name + '.key': privatekey,
                name + '.crt': fullchain,
                name + '.chain.pem': chain,
            }
        else:
            # One sub-folder per domain
            directory = directory / name
            files = {
                'privkey.pem': privatekey,
                'cert.pem': cert,
                'chain.pem': chain,
                'fullchain.pem': fullchain,
            }

        # parents=True so a nested output path that does not exist yet also works
        directory.mkdir(parents=True, exist_ok=True)

        for fileName, content in files.items():
            with (directory / fileName).open('w') as f:
                f.write(content)

###########################################################################################################

class AcmeCertificateFileHandler(watchdog.events.PatternMatchingEventHandler):
    """Watchdog handler that re-exports certificates when a watched file changes.

    Events are debounced: the first matching event starts a 2 second timer, and
    further events are ignored until that timer's run has finished (``isWaiting``).
    """
    def __init__(self, exporter : AcmeCertificateExporter, settings : dict):
        self.__exporter = exporter
        self.__settings = settings

        self.isWaiting = False
        self.lock = threading.Lock()

        # Set the patterns for PatternMatchingEventHandler
        watchdog.events.PatternMatchingEventHandler.__init__(self, patterns = [ self.__settings["fileSpec"] ],
                                                                    ignore_directories = True,
                                                                    case_sensitive = False)

    def on_created(self, event):
        logging.debug("Watchdog received created event - %s.", event.src_path)
        self.handleEvent(event)

    def on_modified(self, event):
        logging.debug("Watchdog received modified event - %s.", event.src_path)
        self.handleEvent(event)

    def handleEvent(self, event):
        """Schedule a single delayed export run for a burst of file events."""
        if not event.is_directory:
            logging.info("Certificates changed found in file: {}".format(event.src_path))

            with self.lock:
                if not self.isWaiting:
                    self.isWaiting = True # trigger the work just once (multiple events get fired)
                    self.timer = threading.Timer(2, self.doTheWork, args=[event])
                    self.timer.start()

    def doTheWork(self, *args, **kwargs):
        '''
        Timer callback: export the certificates from the file of the triggering
        event and, if enabled, restart the affected containers.

        This is a workaround to handle multiple events for the same file.
        ``isWaiting`` is always reset (under the lock) when the run ends, even if
        extraction fails; otherwise no later change would ever be processed.
        '''
        logging.debug("DEBUG : starting the work")

        try:
            if not args:
                logging.error("No event passed to worker")
                return

            domains = self.__exporter.extractCertificates(args[0].src_path)

            if self.__settings["restartContainers"]:
                try:
                    self.restartContainerWithDomains(domains)
                except Exception:
                    logging.warning("Unable to restart containers", exc_info=True)
        except Exception:
            # Runs on a timer thread: log instead of letting the thread die silently
            logging.exception("Unable to extract certificates")
        finally:
            with self.lock:
                self.isWaiting = False

        logging.debug('DEBUG : finished')

    def restartContainerWithDomains(self, domains):
        """Restart every labelled container whose label lists one of ``domains``.

        The label value is a comma separated domain list; surrounding whitespace
        around each entry is ignored. Nothing is restarted in dry-run mode.
        """
        client = docker.from_env()
        try:
            for c in client.containers.list(filters = {"label" : DOCKER_LABLE}):
                restartDomains = {d.strip() for d in c.labels[ DOCKER_LABLE ].split(',')}
                if not restartDomains.isdisjoint(domains):
                    logging.info("Restarting container: {}".format(c.id))
                    # The settings key is "dryRun" (reading "dry" raised KeyError)
                    if not self.__settings["dryRun"]:
                        c.restart()
        finally:
            client.close()

###########################################################################################################
###########################################################################################################

if __name__ == "__main__":
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.DEBUG)

    logging.info("Traefik Certificate Exporter starting....")

    ###########################################################################################################
    parser = argparse.ArgumentParser(description="Extract traefik letsencrypt certificates.")

    # Value options default to None (flags to False) so that only options that
    # were actually passed override the values loaded from the config file.
    parser.add_argument("-c", "--config-file", dest="configFile", default=None, type=str,
                                help="JSON file with settings; command line options take precedence (default: %(default)s)")
    parser.add_argument("-wp", "--watch-path", dest="watchPath", default=None, type=str,
                                help="the path to watch for changes (default: {})".format(settings["watchPath"]))
    parser.add_argument("-fs", "--file-spec", dest="fileSpec", default=None, type=str,
                                help="file that contains the traefik certificates (default: {})".format(settings["fileSpec"]))
    parser.add_argument("-od", "--output-directory", dest="outputPath", default=None, type=str,
                                help="The folder to export the certificates into (default: {})".format(settings["outputPath"]))
    parser.add_argument("--traefik-resolver-id", dest="traefikResolverId", default=None,
                                help="Traefik certificate-resolver-id.")
    parser.add_argument("-f", "--flat", action="store_true", dest="flat",
                                help="If specified, write all certificates into a single folder")
    parser.add_argument("-r", "--restart_container", action="store_true", dest="restartContainer",
                                help="If specified, any container that is labeled with '" + DOCKER_LABLE + "=<DOMAIN>' will be restarted if the domain name of a generated certificate matches the value of the label. Multiple domains can be separated by ','")
    parser.add_argument("--dry-run", action="store_true", dest="dry",
                                help="Don't write files and do not restart docker containers.")

    group = parser.add_mutually_exclusive_group()
    group.add_argument("-id", "--include-domains", nargs="*", dest="includeDomains", default=None,
                                help="If specified, only certificates that match domains in this list will be extracted")
    group.add_argument("-xd", "--exclude-domains", nargs="*", dest="excludeDomains", default=None,
                                help="If specified, certificates that match domains in this list will be ignored")

    ###########################################################################################################

    args = parser.parse_args()

    # Do we need to load settings from a config file?
    if args.configFile:
        if os.path.exists(args.configFile):
            logging.info("Loading config file: {}".format(args.configFile))
            with open(args.configFile, encoding='utf-8') as f:
                fileSettings = json.load(f)
            # Merge over the defaults so keys missing from the file keep their default values
            fileDomains = fileSettings.pop("domains", None) or {}
            settings.update(fileSettings)
            settings["domains"].update(fileDomains)
        else:
            logging.warning("Config file does not exist, ignoring: {}".format(args.configFile))

    # Let's override the settings from the command line (only options actually given)
    for key in ("watchPath", "fileSpec", "outputPath", "traefikResolverId"):
        value = getattr(args, key)
        if value is not None:
            settings[key] = value

    if args.flat:
        settings["flat"] = True
    if args.restartContainer:
        settings["restartContainers"] = True
    if args.dry:
        settings["dryRun"] = True

    if args.includeDomains:
        settings["domains"]["include"] = args.includeDomains
    if args.excludeDomains:
        settings["domains"]["exclude"] = args.excludeDomains

    # Let's validate that the path we are being asked to watch actually exists
    if not os.path.exists(settings["watchPath"]):
        logging.error("Watch Path does not exist. Exiting...")
        sys.exit(-1)

    logging.info("Watching Path: {}".format(settings["watchPath"]))
    logging.info("File Spec: {}".format(settings["fileSpec"]))
    logging.info("Output Path: {}".format(settings["outputPath"]))

    exporter = AcmeCertificateExporter(settings=settings)

    event_handler = AcmeCertificateFileHandler(exporter=exporter,
                                               settings=settings)

    observer = watchdog.observers.Observer()
    observer.schedule(event_handler, path=settings["watchPath"], recursive=False)

    observer.start()
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        observer.stop()
    observer.join()

    logging.info("Traefik Certificate Exporter stopping....")
