#!/usr/bin/env python
import regexploit.hook

# Install the re.compile hook before the remaining imports so that regexes
# compiled while the following modules load are captured as well.
regexploit.hook.install()

import importlib
import pkgutil
import sys

from regexploit.ast.sre import SreOpParser
from regexploit.redos import find
from regexploit.output.text import TextOutput

# Import Python modules and analyse the regexes they compile at import time, which are captured by hooking re.compile


def main():
    def onerror(name):
        print("Cannot load", name)

    names = tuple(sys.argv[1:]) if len(sys.argv) > 1 else None
    sys.argv = sys.argv[:1]
    if names:
        regexploit.hook.regexes.clear()

    output = TextOutput()
    for p in pkgutil.walk_packages(sys.path, onerror=onerror):
        # Importing some modules is disruptive https://xkcd.com/353/
        if (
            not names
            and p.name not in ("antigravity", "rstpep2html", "setup")
            and not p.name.startswith(("test", "pip", "setuptools", "idlelib", "rst2"))
            and not p.name.endswith(("__main__", ".main", ".conftest"))
            and ".test" not in p.name
        ) or (names and p.name.startswith(names)):
            print(f"Importing {p.name}")
            try:
                importlib.import_module(p.name)
                hooked_regex: regexploit.hook.CompiledRegex
                for hooked_regex in regexploit.hook.get_and_clear_regexes():
                    output.next()
                    parsed = SreOpParser().parse_sre(
                        hooked_regex.pattern, hooked_regex.flags
                    )
                    for redos in find(parsed):
                        if redos.starriness > 2:
                            output.record(
                                redos,
                                hooked_regex.pattern,
                                filename=hooked_regex.last_tb.filename,
                                lineno=hooked_regex.last_tb.lineno,
                                context=hooked_regex.last_tb.line,
                            )

            except Exception as e:
                print("Cannot load", p, e)
    print(f"Processed {output.regexes} regexes")


if __name__ == "__main__":
    # Run the scan only when executed as a script, not when imported.
    main()
