#!/usr/bin/env python
#
# Copyright 2018 Rick Chang <chchang915@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys, os, netrc, json, subprocess, argparse, shlex, re
import xml.etree.ElementTree as etree
import ssl

try:
    import configparser as ConfigParser
    from urllib.parse import urlparse, urlencode, quote
    from urllib.error import HTTPError
    import urllib.request as request
except ImportError:
    import ConfigParser
    from urlparse import urlparse
    from urllib import urlencode, quote
    import urllib2 as request

# Default settings. InitConfig() overwrites entries in place with values read
# from the config file, and Config.__init__ then reads them (command-line
# options win where one exists). Values are kept as strings because that is
# what the config file parser returns.
CONFIG = {
    # Name of the repo command used to resolve a project's install path.
    'repo': 'repo',
    # Credentials; while both stay None, InitAccount() looks them up in netrc.
    'user': None,
    'password': None,
    # Gerrit server queried when -g/--gerrit is not given.
    'gerritUrl': 'https://review.gerrithub.io',
    # Converted with int() and handed to urlopen() as the timeout.
    'connectTimeout': '30',
    # netrc file used when -n/--netrc-file is not given ('~' is expanded).
    'netrcPath': '~/.netrc',
    # Whitespace-separated fetch protocols, tried in this order.
    'fetchProtocols': 'http ssh git',
    # Config file used when -c/--config-file is not given ('~' is expanded).
    'configPath': '~/.pickrc',
    # Separator lines printed around sections of output.
    'delimiter': '-' * 80,
    'delimiterEnd': '=' * 80,
    # Command run against FETCH_HEAD in --dryrun mode.
    'patchPreview': 'git log --no-decorate -1',
}

# Working directory captured when the module is loaded; InstallPatches()
# changes back to it after working inside a project directory.
CUR_PATH    = os.getcwd()
# Terminal escape sequences used by Loge/Logm/Logw; NONE resets the color.
RED         = "\x1b[31m"
GREEN       = "\x1b[32m"
YELLOW      = "\x1b[33m"
NONE        = "\x1b[0m"
# Section name that FakeSection puts in front of the section-less config file.
CONFIG_SECTION = 'section'

def Log(s, endPat=None):
    """Print ``s`` to stdout and flush right away.

    ``endPat`` is forwarded as ``print``'s ``end`` argument; ``None`` keeps
    the default trailing newline.
    """
    stream = sys.stdout
    print(s, end=endPat, file=stream)
    stream.flush()

def Loge(s):
    """Log the string ``s`` in red (error output)."""
    Log("".join((RED, s, NONE)))

def Logm(s):
    """Log the string ``s`` in green (headline/message output)."""
    Log("".join((GREEN, s, NONE)))

def Logw(s):
    """Log the string ``s`` in yellow (warning output)."""
    Log("".join((YELLOW, s, NONE)))

def Logv(s, verbose):
    """Log ``s`` only when ``verbose`` is truthy; otherwise do nothing."""
    if not verbose:
        return
    Log(s)

class FakeSection(object):
    """File wrapper that yields a ``[name]`` header before the real content.

    This lets a config parser read a file that has no section header of its
    own. The header is produced exactly once, by whichever of ``readline``
    or iteration is used first.
    """

    def __init__(self, f, name):
        self.f = f
        self.first = True
        self.sectionName = "[%s]\n" % name

    def readline(self):
        """Return the synthetic header on the first call, then file lines."""
        if not self.first:
            return self.f.readline()
        self.first = False
        return self.sectionName

    def __iter__(self):
        """Yield the header (if still pending) followed by the file lines."""
        if self.first:
            self.first = False
            yield self.sectionName
        while True:
            current = self.f.readline()
            if not current:
                # An empty read marks the end of the wrapped file.
                return
            yield current

class Config(object):
    """Effective settings: command-line options layered over ``CONFIG``.

    For ``gerrit``, ``user``, ``password``, ``netrc_file`` and ``preview`` a
    truthy command-line value wins; everything else comes straight from
    ``CONFIG``. The parsed options stay available as ``opts``.
    """

    def __init__(self, opts):
        netrc_choice = opts.netrc_file or CONFIG['netrcPath']

        self.gerrit = opts.gerrit or CONFIG['gerritUrl']
        self.repo = CONFIG['repo']
        self.user = opts.user or CONFIG['user']
        self.password = opts.password or CONFIG['password']
        # CONFIG stores the timeout as a string.
        self.timeout = int(CONFIG['connectTimeout'])
        self.netrc_file = os.path.expanduser(netrc_choice)
        self.protocols = CONFIG['fetchProtocols']
        self.delim = CONFIG['delimiter']
        self.delim_end = CONFIG['delimiterEnd']
        self.preview = opts.preview or CONFIG['patchPreview']
        self.opts = opts

def InitConfig(opts):
    """Merge the optional config file into ``CONFIG`` and build a ``Config``.

    The file path is ``opts.config_file`` when given, otherwise
    ``CONFIG['configPath']``. A missing default file is not an error; a
    missing explicitly requested file is.

    Returns:
        A ``Config`` instance, or ``None`` when the requested file does not
        exist or the file cannot be read/parsed.
    """
    explicit = bool(opts.config_file)
    config_file = os.path.expanduser(
        opts.config_file if explicit else CONFIG['configPath'])
    if not os.path.isfile(config_file):
        if explicit:
            Loge("Can't find config file '%s'" % config_file)
            return None
        return Config(opts)

    Log("Use config file: %s" % config_file)
    # SafeConfigParser and readfp() were removed in Python 3.12; prefer the
    # modern API and fall back to the legacy names only where it is missing
    # (Python 2).
    if hasattr(ConfigParser.ConfigParser, 'read_file'):
        config = ConfigParser.ConfigParser()
        read_config = config.read_file
    else:
        config = ConfigParser.SafeConfigParser()
        read_config = config.readfp

    try:
        # The context manager closes the file once parsing is finished.
        with open(config_file) as f:
            read_config(FakeSection(f, CONFIG_SECTION))
        # Only keys already known to CONFIG are taken from the file.
        for key in list(CONFIG):
            if config.has_option(CONFIG_SECTION, key):
                CONFIG[key] = config.get(CONFIG_SECTION, key)
    except (IOError, OSError, ConfigParser.Error) as e:
        Loge("Can't read config file '%s': %s" % (config_file, e))
        return None
    return Config(opts)

def InitAccount(configs):
    """Fill in ``configs.user``/``configs.password`` from netrc when unset.

    Returns 1 when a netrc file was requested on the command line but does
    not exist, otherwise 0. Credentials that are already (even partially)
    set are left untouched; missing netrc credentials only produce warnings.
    """
    requested_netrc = configs.opts.netrc_file
    if requested_netrc is not None and not os.path.isfile(configs.netrc_file):
        Loge("Can't find netrc file '%s'" % (configs.netrc_file))
        return 1

    nothing_given = configs.user is None and configs.password is None
    if not nothing_given:
        return 0

    login, secret = GetLoginInfoNetrc(configs.gerrit, configs.netrc_file)
    configs.user = login
    configs.password = secret
    if login is None:
        Logw("WARNING: user is None")
    if secret is None:
        Logw("WARNING: password is None")
    return 0

def RunShell(cmd):
    """Run ``cmd`` through the shell and return its exit status."""
    exit_status = subprocess.call(cmd, shell=True)
    return exit_status

def Run(cmd):
    """Run ``cmd`` (split with shlex, no shell) and return its exit status.

    The child inherits this process's stdout/stderr. Any failure to split or
    start the command is logged and reported as status 1.
    """
    try:
        child = subprocess.Popen(shlex.split(cmd), bufsize=0)
        child.communicate()
    except Exception as err:
        Loge("Run cmd '%s' fail. (%s)" % (cmd, err))
        return 1
    else:
        return child.returncode

def RunLog(cmd, silent=False, outErr=True):
    """Run ``cmd`` (split with shlex) and return its stripped output text.

    Args:
        cmd: command line to execute without a shell.
        silent: when true, failures are not logged.
        outErr: when true, stderr is merged into the captured output.

    Returns:
        The output decoded as UTF-8 and stripped, or ``None`` on any failure
        (including a non-zero exit status).
    """
    stderr_target = None
    if outErr:
        stderr_target = subprocess.STDOUT
    try:
        raw = subprocess.check_output(shlex.split(cmd), stderr=stderr_target)
        text = raw.decode('utf-8')
        return text.strip()
    except Exception as err:
        if silent:
            return None
        Loge("Run '%s' fail. (%s)" % (cmd, err))
        return None

class GerritRest(object):
    """Minimal client for authenticated Gerrit REST queries.

    Constructing an instance installs a process-wide urllib opener that
    answers digest and basic authentication challenges for ``url`` with the
    given credentials.
    """
    # https://github.com/GerritCodeReview/gerrit/blob/master/Documentation/rest-api.txt
    def __init__(self, url, username, password):
        url = url.strip('/')
        password_mgr = request.HTTPPasswordMgrWithDefaultRealm()
        password_mgr.add_password(None, url, username, password)
        digest_auth = request.HTTPDigestAuthHandler(password_mgr)
        basic_auth = request.HTTPBasicAuthHandler(password_mgr)
        opener = request.build_opener(digest_auth, basic_auth)
        request.install_opener(opener)
        self.server = url

    def query(self, q, verbose=False, timeout=None):
        """Send a GET to ``<server>/a<q>`` and return the decoded JSON.

        Args:
            q: path and query string, starting with '/'.
            verbose: when true, log the URL and the decoded response.
            timeout: passed to ``urlopen``.

        Returns:
            The parsed JSON value, or ``None`` if anything fails (the error
            is logged).
        """
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json;charset=UTF-8"
        }

        # NOTE(security): this switches off TLS certificate verification for
        # every HTTPS connection made by this process (kept so servers with
        # self-signed certificates continue to work).
        try:
            _create_unverified_https_context = ssl._create_unverified_context
        except AttributeError:
            # Older Pythons have no such hook; nothing to override.
            pass
        else:
            ssl._create_default_https_context = _create_unverified_https_context
        url = "%s/a%s" % (self.server, q)
        Logv("Query: %s" % url, verbose)
        try:
            req = request.Request(url, None, headers)
            response = request.urlopen(req, None, timeout)
            try:
                # skip )]}'
                response.readline()
                data = json.loads(response.read().decode("utf-8"))
            finally:
                # Always release the connection, even if decoding fails.
                response.close()
            Logv("Response: %s" % data, verbose)
        except Exception as e:
            Loge("query fail: %s" % e)
            return None
        return data

def GetMachine(url):
    """Return the host part of ``url`` for use as a netrc machine name.

    Userinfo (``user:password@``) and the port are removed; the host keeps
    its original letter case. A bracketed IPv6 literal is returned with its
    brackets. Returns '' when ``url`` has no network location (for example
    when the scheme is missing).
    """
    netloc = urlparse(url).netloc
    # Drop any "user:password@" prefix; it is not part of the machine name.
    host = netloc.rpartition('@')[2]
    if host.startswith('['):
        # IPv6 literal: the colons inside the brackets are not a port separator.
        end = host.find(']')
        return host[:end + 1] if end != -1 else host
    # Cut at the port separator textually, so ports such as ":080" or ":0"
    # are removed as well.
    return host.partition(':')[0]

def GetLoginInfoNetrc(url, path):
    """Look up the login and password for ``url``'s host in a netrc file.

    Args:
        url: server URL; its host is used as the netrc machine name.
        path: netrc file to read.

    Returns:
        ``(login, password)``, or ``(None, None)`` when the file does not
        exist, cannot be read or parsed, the URL has no host, or the host
        has no entry. Every case except a missing file logs the reason.
    """
    if not os.path.isfile(path):
        return None, None

    try:
        handle = netrc.netrc(path)
    except (netrc.NetrcParseError, IOError, OSError) as e:
        # An unreadable or malformed netrc is reported like the other lookup
        # failures instead of ending the program with a traceback.
        Log("Can't read netrc file '%s': %s" % (path, e))
        return None, None
    machine = GetMachine(url)
    if not machine:
        Log("Can't find machine name in '%s' for netrc" % (url))
        return None, None
    info = handle.authenticators(machine)
    if not info:
        Log("Can't find machine '%s' in '%s'" % (machine, path))
        return None, None
    # authenticators() yields (login, account, password); account is unused.
    login, _account, password = info
    return (login, password)

def ParseArguments(argv):
    """Parse the command line (``argv[0]`` is the program name).

    Returns the parsed options, or ``None`` after printing the help text
    when neither a change number nor a query was supplied.
    """
    ap = argparse.ArgumentParser()
    add = ap.add_argument
    # Keyword arguments shared by every on/off switch.
    switch = {'action': 'store_true', 'default': False}

    add('change_num', nargs='*',
        help='ex. \'12345\', \'12345/1\'')
    add('-u', '--user', type=str,
        help='gerrit user id')
    add('-p', '--password', type=str,
        help='gerrit HTTP password')
    # https://review.openstack.org/Documentation/user-search.html
    add('-q', '--query', type=str,
        help='get patches from query command (change_num arguments will be ignored if any) ex. \'branch:master status:merged after:"2018-11-17 22:06:00"\'')
    add('--query-only',
        help='do not install patch', **switch)
    add('-r', '--preview', type=str,
        help='preview command for changes (default: %s)' % CONFIG['patchPreview'])
    add('-g', '--gerrit', type=str,
        help='gerrit server url (default: %s)' % CONFIG['gerritUrl'])
    add('-d', '--dryrun',
        help='show what would be done', **switch)
    add('-n', '--netrc-file', type=str,
        help='netrc path (default: %s). (if user or password has been specified, netrc config will be ignored)' % CONFIG['netrcPath'])
    add('-c', '--config-file', type=str,
        help='config path (default: %s)' % CONFIG['configPath'])
    add('-m', '--manifest', type=str, metavar='NAME.xml',
        help='assign manifest file to resolve patch install path instead of using repo command')
    add('-i', '--install-path', type=str,
        help='assign patch install path instead of resolving path by repo command or manifest')
    add('-F', '--full-path',
        help='display the full install path instead of the relative install path', **switch)
    add('-N', '--name-path',
        help='display the project name instead of the relative install path', **switch)
    add('-x', '--exec', dest='exe', type=str,
        help='append command after all changes installed in each project')
    add('--verbose',
        help='show more logs', **switch)
    add('-v', '--version', action='version', version='1.1.1')

    opts = ap.parse_args(argv[1:])
    if opts.change_num or opts.query:
        return opts

    # Nothing to pick and nothing to search for: show usage instead.
    ap.print_help()
    return None

def ChangeIdToRefId(change_num, patchset_id):
    """Build the ``refs/changes/NN/<change>/<patchset>`` ref name.

    ``NN`` is the change number modulo 100, zero-padded to two digits.
    """
    shard = str(int(change_num) % 100).zfill(2)
    return "refs/changes/%s/%s/%s" % (shard, change_num, patchset_id)

def GetRemote(protocols, fetch_info):
    """Return the fetch URL of the first protocol present in ``fetch_info``.

    ``protocols`` is tried in order; '' is returned when none of them is a
    key of ``fetch_info``.
    """
    candidates = (fetch_info[name]['url']
                  for name in protocols if name in fetch_info)
    return next(candidates, '')

def ResponseToPatch(configs, response, change_num, patchset_id):
    """Turn a change query response into a patch description.

    Args:
        configs: settings; ``configs.protocols`` lists the preferred fetch
            protocols, separated by whitespace.
        response: decoded query result; only its first entry is used.
        change_num: change number the response belongs to.
        patchset_id: requested patch set, or a falsy value to use the number
            of the current revision.

    Returns:
        A dict with 'project', 'remote', 'patch' and 'ref' keys, or ``None``
        when no fetch URL is available for any configured protocol.
    """
    info = response[0]
    project = info['project']
    current_revision = info['current_revision']
    revision = info['revisions'][current_revision]
    if not patchset_id:
        patchset_id = revision['_number']
    # A revision without a 'fetch' entry is treated like one that offers no
    # usable protocol, so it ends in the error message below, not a KeyError.
    fetch_info = revision.get('fetch', {})

    remote = GetRemote(re.split(r'\s+', configs.protocols), fetch_info)
    if not remote:
        Loge("Can't find remote project for %s" % change_num)
        Log(fetch_info)
        return None

    patch = {
        'project': project,
        'remote': remote,
        'patch': "%s/%s" % (change_num, patchset_id),
        'ref': ChangeIdToRefId(change_num, patchset_id),
    }
    return patch

def GetChangeIdInfo(change_num):
    """Split ``'<change>'`` or ``'<change>/<patchset>'`` into its parts.

    Returns ``(change, patchset)`` as strings, where ``patchset`` is '' when
    it was not given, or ``(None, None)`` when the input is malformed.
    """
    invalid = (None, None)
    number, _, patchset_id = change_num.partition('/')
    if '/' in patchset_id:
        # More than one slash in the input.
        return invalid
    if not number.isdigit():
        return invalid
    if patchset_id and not patchset_id.isdigit():
        return invalid
    return number, patchset_id

def QueryChanges(configs, rest):
    """Run the ``--query`` search and list the changes it finds.

    Returns the matching change numbers as strings in the reverse of the
    order the server reported them, ``[]`` when nothing matched, or ``None``
    when the query itself failed.
    """
    opts = configs.opts
    endpoint = '/changes/?q=%s' % quote(opts.query)

    response = rest.query(endpoint, opts.verbose, configs.timeout)
    Logv(configs.delim, opts.verbose)
    if response is None:
        return None
    if not response:
        return []

    numbers = []
    for change in response:
        Log("%8s - %s" % (change['_number'], change['subject']))
        numbers.append(str(change['_number']))
    Log("(Total: %d changes)" % len(response))
    Log(configs.delim_end)
    numbers.reverse()
    return numbers

def GetPatchList(configs, rest, change_nums):
    """Resolve every requested change into a patch, grouped by project.

    Returns a dict mapping project name to its patches in input order, or
    ``None`` as soon as one input is malformed, cannot be queried, is
    unknown to the server, or has no usable fetch URL.
    """
    opts = configs.opts
    by_project = {}
    for input_num in change_nums:
        change_num, patchset_id = GetChangeIdInfo(input_num)
        if not change_num:
            Loge("error: Unknown change number format '%s'" % input_num)
            return None

        Log("Searching patch '%s' ..." % change_num, '')
        Logv("\n%s" % configs.delim, opts.verbose)
        response = rest.query(
            '/changes/?q=%s&o=%s' % (change_num, "CURRENT_REVISION"),
            opts.verbose, configs.timeout)
        if not response:
            # None means the query failed; an empty result means no match.
            Logv("fail", not opts.verbose)
            if response is not None:
                Loge("Can't find change number '%s'" % change_num)
            return None
        Log(configs.delim_end if opts.verbose else "done")

        patch = ResponseToPatch(configs, response, change_num, patchset_id)
        if not patch:
            return None

        patch['input_num'] = input_num
        by_project.setdefault(patch['project'], []).append(patch)
    return by_project

def InstallPatch(configs, remote, ref_id):
    """Fetch ``ref_id`` from ``remote`` and apply it, or preview it in dry-run.

    Returns 1 when the fetch fails; otherwise the exit status of the
    cherry-pick, or of the preview command when ``--dryrun`` is set.
    """
    opts = configs.opts
    fetch_status = Run("git fetch %s %s" % (remote, ref_id))
    if fetch_status:
        return 1

    if opts.dryrun:
        follow_up = "%s %s" % (configs.preview, "FETCH_HEAD")
    else:
        follow_up = "git cherry-pick FETCH_HEAD"
    return Run(follow_up)

def GetInputNums(patches, start = 0):
    """Join the 'input_num' of ``patches[start:]`` with single spaces."""
    selected = (patches[idx]['input_num']
                for idx in range(start, len(patches)))
    return " ".join(selected).strip()

def CheckInstallPath(full_path, project, patches, configs):
    """Check that ``full_path`` is usable as the project's install path.

    Returns True when the path is set and exists. Otherwise a warning is
    printed together with the git commands the user can run by hand, and
    False is returned.
    """
    opts = configs.opts
    if opts.manifest:
        method = "manifest '%s'" % opts.manifest
    else:
        method = "command '%s list'" % configs.repo

    if full_path and os.path.exists(full_path):
        return True

    # Generic text; both failure cases below replace it.
    error_msg = "The install path for the project can't be found."
    if not full_path:
        error_msg = "The install path can't be resolved by %s" % method
    else:
        error_msg = "The install path '%s' is not found." % full_path

    manual_nums = GetInputNums(patches)
    Logm("[Project: %s]" % project)
    Logw('''\
WARNING: %s
    (use '-i INSTALL_PATH' for '%s' to assign the right path)
    (use the following commands manually in the right path)''' % (error_msg, manual_nums))
    Log(configs.delim)
    for entry in patches:
        Log("git fetch %s %s && git cherry-pick FETCH_HEAD" % (entry['remote'], entry['ref']))
    Log(configs.delim_end)
    return False

def InstallPatches(configs, project, patches, root):
    """Install all ``patches`` of one project in that project's directory.

    The directory comes from ``--install-path``, else from the manifest
    tree ``root``, else from the repo command. Installation stops at the
    first failing patch. When every patch succeeded and ``--exec`` was
    given, that command is run in the project directory.

    Returns 0 on success, 1 when the path is unusable or a patch failed.
    """
    opts = configs.opts
    path = None
    full_path = None
    if opts.install_path is not None:
        path = opts.install_path
        full_path = os.path.abspath(path)
    elif opts.manifest is not None:
        node = root.find("./project[@name='%s']" % project)
        if node is not None:
            path = node.get('path')
        if path is not None:
            full_path = os.path.abspath(path)
    else:
        path = RunLog("%s list -p %s" % (configs.repo, project), silent=True, outErr=False)
        full_path = RunLog("%s list -pf %s" % (configs.repo, project), silent=True, outErr=False)

    if not CheckInstallPath(full_path, project, patches, configs):
        return 1

    if opts.full_path:
        shown = full_path
    elif opts.name_path:
        shown = project
    else:
        shown = path
    Logm("[%s]" % shown.strip('/'))

    os.chdir(full_path)
    fail = 0
    for index, patch in enumerate(patches):
        remote, ref_id = patch['remote'], patch['ref']

        if opts.dryrun:
            Log("(dryrun) Pick: %s %s" % (remote, ref_id))
        else:
            Log("Pick: %s %s" % (remote, ref_id))
        Log(configs.delim)
        status = InstallPatch(configs, remote, ref_id)
        Log(configs.delim_end)
        if status:
            Loge("error: pick fail (unfinished patches: %s)" % GetInputNums(patches, index))
            fail = 1
            break
    else:
        # Reached only when no patch failed.
        if opts.exe:
            RunShell(opts.exe)
    os.chdir(CUR_PATH)
    return fail

def _MaskedSettings(configs):
    """Return a copy of ``vars(configs)`` with every password masked."""
    dump = dict(vars(configs))
    if dump.get('password') is not None:
        dump['password'] = '***'
    opts = dump.get('opts')
    if opts is not None and getattr(opts, 'password', None) is not None:
        # The parsed options carry the command-line password as well.
        masked = dict(vars(opts))
        masked['password'] = '***'
        dump['opts'] = argparse.Namespace(**masked)
    return dump

def main(argv):
    """Entry point: resolve the requested changes and install them.

    Args:
        argv: full argument vector, including the program name.

    Returns:
        0 on success (also when a query matches nothing or ``--query-only``
        is set), 1 on any error or when at least one project failed.
    """
    opts = ParseArguments(argv)
    if opts is None:
        return 1

    configs = InitConfig(opts)
    if configs is None:
        return 1

    if InitAccount(configs):
        return 1

    if opts.verbose:
        Log(configs.delim)
        # Never write the password to the terminal or to captured logs.
        Log(_MaskedSettings(configs))
        Log(configs.delim_end)

    change_nums = opts.change_num
    rest = GerritRest(configs.gerrit, configs.user, configs.password)
    if opts.query:
        # A query replaces any change numbers given on the command line.
        Log("Querying change numbers from '%s'" % configs.gerrit)
        Log("Searching patches by '%s' ..." % opts.query)
        Log(configs.delim)
        change_nums = QueryChanges(configs, rest)
        if change_nums is None:
            return 1
        if not change_nums:
            Log("Can't find any patches")
            return 0

    if opts.query_only:
        return 0

    root = None
    if opts.manifest:
        try:
            root = etree.parse(opts.manifest).getroot()
        except Exception as e:
            Log("Manifest '%s' parse error: %s" % (opts.manifest, e))
            return 1

    Log("Getting patches from '%s'" % configs.gerrit)
    Log(configs.delim)
    patch_list = GetPatchList(configs, rest, change_nums)
    Logv(configs.delim_end, not opts.verbose)
    if not patch_list:
        return 1

    Log("Fetching and installing patches")
    Log(configs.delim)
    ret = 0
    # Keep going after a failing project so the others still get installed.
    for project, patches in patch_list.items():
        if InstallPatches(configs, project, patches, root):
            ret = 1
    return ret

# Run only when executed as a script; main()'s return value becomes the
# process exit status.
if __name__ == '__main__':
    sys.exit(main(sys.argv))
