#!/usr/bin/env python
from collections import OrderedDict
from datetime import datetime, date, timedelta
from dateutil.relativedelta import relativedelta
from ply.lex import TOKEN
import calendar
import code
import locale
import re
import sys
import time
import types

# User-facing reference text printed by the `help` command / `help` name.
HELP = '''
OBJECTS

    DELTA

            a timedelta object can be interpreted as a
            chain of (amount, unit) consisting of a
            number followed by a time unit in ISO format,
            with case input relaxed except for
            differentiating months and minutes:

                1D+1d  # case insensitive
                3m+3M  # except for months and minutes
               -2M+2s  # accepts negative numbers
                10Y3s  # join them together instead of adding

    DATETIME

            a datetime object represents a point in
            time. Can be interpreted in various forms
            such as follows:

                1611269086 # unix timestamp in seconds
                2020/12/31 22:22
                2020 Jan 12
                2020 December 20
                2020/12/31 22:22:22
                2020/12/31
                today

VARIABLES

            there are four built-in variables:

                T or today
                YD or yesterday
                TM or tomorrow
                N  or now

            but you can also assign objects to a
            named variable, like so:
                foo=1d
                bar=YD

OPERATORS
            +  : adds deltas to points in time
            -  : takes the difference between two points
                 in time and stores a delta

 <,<=,>,>=,==,!=  : compares two points in time and returns
                 a boolean

KEYWORDS
            in : converts a delta to seconds, minutes,
                 hours, days or weeks, or a point in
                 time to a unix timestamp (in unix)

FUNCTIONS / ATTRIBUTES

    wait DELTA       : sleeps for the duration

    next WEEKDAY     : returns the date for the next weekday

    last WEEKDAY     : returns the date for the last weekday

dayofweek TIME_POINT : returns weekday for time point
    dow  TIME_POINT  : same as dayofweek
    TIME_POINT.dow   : same as dayofweek
'''

# Cached value of "now"; filled in lazily by the N / now token rule.
n = None


def replace(replacee, replacement, string):
    """Substitute every match of regex `replacee` in `string` with `replacement`."""
    return re.sub(replacee, replacement, string)


# English weekday names, Monday first (same order as date.weekday()).
days = [
    'Monday', 'Tuesday', 'Wednesday', 'Thursday',
    'Friday', 'Saturday', 'Sunday',
]
# English month names, January first.
months = [
    'January', 'February', 'March', 'April', 'May', 'June',
    'July', 'August', 'September', 'October', 'November', 'December',
]
# Three-letter abbreviations derived from the full names above.
days_abbrev = [day_name[:3] for day_name in days]
months_abbrev = [month_name[:3] for month_name in months]

# Token type names handed to the PLY lexer/parser. Every token type that a
# lexer rule returns is expected to be listed here.
tokens = (
    'PLUS','MINUS','EQUALS',
    'LPAREN','RPAREN',
    'YD', # yesterday
    'N',  # now
    'T',  # today
    'TM', # tomorrow
    'UNIT',
    'CYCLIC_OP',
    'IN',
    'GT', 'GE', 'LT', 'LE', 'EQ', 'NE',
    'NAME',
    'DELTA', 
    'TIMESTAMP', 
    'WEEKDAY', 
    'PERIOD', 
    'SEMICOLON', 
    'DATETIME', 
    )

# Simple tokens defined as plain regex strings (no action code).
# NOTE(review): several operators share a prefix ('>' vs '>=', '=' vs '==');
# this presumably relies on PLY ordering string rules by decreasing regex
# length -- confirm against the PLY documentation.

t_SEMICOLON = r';'
t_PERIOD    = r'\.'
t_PLUS      = r'\+'
t_MINUS     = r'-'
t_EQUALS    = r'='
t_LPAREN    = r'\('
t_RPAREN    = r'\)'
t_GT        = r'>'
t_GE        = r'>='
t_LT        = r'<'
t_LE        = r'<='
t_EQ        = r'=='
t_NE        = r'!='
t_IN        = r'(?i)in'

# Maps a single-letter delta unit to the long unit name used as a
# timedelta / relativedelta keyword. Only minutes ('M') and months ('m')
# are case sensitive; every other unit accepts both cases.
unit_map = {
        's': 'seconds',
        'S': 'seconds',
        'M': 'minutes',
        'h': 'hours',
        'H': 'hours',
        'd': 'days',
        'D': 'days',
        'w': 'weeks',
        'W': 'weeks',
        'm': 'months',
        'y': 'years',
        'Y': 'years',
        }

# Regex fragments for words that must not be used as variable names. They are
# used both by is_reserved() and as negative lookaheads in the NAME token.
reserved = [
            r'in',
            r'n(ow)?',
            r't(oday)?',
            r'next|last',
            r'yd|yesterday',
            r'seconds|minutes|hours|days|weeks',
            '|'.join(days),
            '|'.join(days_abbrev),
            '|'.join(months),
            '|'.join(months_abbrev),
        ]

def is_reserved(k):
    """Return True when name `k` begins with one of the reserved patterns.

    Each entry of `reserved` is tried with `re.match`, so the pattern is
    anchored at the start of `k` only (a prefix match, case sensitive).
    Always returns a bool; the previous implementation fell off the end and
    returned None for non-reserved names.
    """
    return any(re.match(pattern, k) for pattern in reserved)

# All single-letter unit suffixes accepted in a delta literal.
UNITS_STR = 'sSMhHdDwWmyY'
# An unsigned integer or decimal number, captured as one group.
FLOATING_POINT = r'((?:\d*[.])?\d+)'
# A delta literal: number+unit, followed by any number of number+unit pairs.
# A number with no unit is only accepted at the very end of the input ('$').
DELTA_TOKEN = FLOATING_POINT + '['+UNITS_STR+']('+FLOATING_POINT+'(['+UNITS_STR+']|$))*'
# Lexer rule for DELTA: converts the matched literal into a timedelta.
@TOKEN(DELTA_TOKEN)
def t_DELTA(t):
    # Collect (value, unit) pairs in input order. The unit is '' for a
    # trailing number that has no unit letter.
    units_vals = OrderedDict()
    matches = re.findall(FLOATING_POINT + '(['+UNITS_STR+']|$)', t.value)
    for v,u in matches:
        # Keys are lower-cased except 'm'/'M' (months vs minutes). A unit that
        # appears twice keeps only its last value.
        units_vals.update({ u if u in 'mM' else u.lower(): float(v) if v else 1 })
    t.value = parse_units(units_vals)
    if '' in units_vals:
        # A trailing unit-less number is interpreted in the unit that precedes
        # the previous entry's unit inside UNITS_STR (e.g. after 'y' -> 'm',
        # after 'h' -> 'M').
        u,_ = list(units_vals.items())[list(units_vals.keys()).index('')-1]
        next_unit = UNITS_STR[UNITS_STR.index(u)-1]
        if next_unit.lower() == u.lower():
            next_unit = UNITS_STR[UNITS_STR.index(u)-2]
        # NOTE(review): for u == 's' the index is 0, so index-1 wraps around to
        # 'Y' and the trailing number is added as years -- confirm intent.
        le_add = parse_units({next_unit: units_vals['']}) # get successor delta
        t.value += le_add
    return t

# NAME token: an identifier ([a-z_][a-z0-9_]*, case-insensitive) that does not
# start with any of the `reserved` patterns (one negative lookahead each).
t_NAME = '(?i)(' + ''.join([f'(?!{res})' for res in reserved]) + '[a-z_][a-z0-9_]*)'

# WEEKDAY token: a full or abbreviated weekday name; the matched text is
# wrapped in a Weekday object.
@TOKEN(r'(?i)' + '|'.join(days) + '|' + '|'.join(days_abbrev))
def t_WEEKDAY(t):
    t.value = Weekday(t.value)
    return t

# MONTH rule: a full or abbreviated month name; the matched text is wrapped in
# a Month object.
# NOTE(review): 'MONTH' is not listed in `tokens` and no grammar rule uses it;
# presumably the lexer rejects this token type when a bare month name is
# lexed -- confirm against PLY's behavior.
@TOKEN(r'(?i)' + '|'.join(months) + '|' + '|'.join(months_abbrev))
def t_MONTH(t):
    t.value = Month(t.value)
    return t

# CYCLIC_OP token: the words 'next' / 'last'. The docstring below is the token
# regex read by PLY; the token value is the matched text, unchanged.
def t_CYCLIC_OP(t):
    r'(?i)(next|last)'
    return t

# UNIT token: a long unit name (or 'unix') used on the right of `in`. The
# docstring below is the token regex read by PLY; the value is the matched
# text, unchanged.
def t_UNIT(t):
    r'(?i)(seconds|minutes|hours|days|weeks|months|years|unix)'
    return t

# YD token: 'yesterday' / 'yd'; the value is yesterday's date (a date object).
# The docstring below is the token regex read by PLY.
def t_YD(t):
    r'(?i)(yesterday|yd)'
    t.value = (datetime.today() - timedelta(days=1)).date()
    return t

# TM token: 'tomorrow' / 'tm'; the value is tomorrow's date (a date object).
# The docstring below is the token regex read by PLY.
def t_TM(t):
    r'(?i)(tomorrow|tm)'
    t.value = (datetime.today() + timedelta(days=1)).date()
    return t

# T token: 'today' / 't'; the value is today's date (a date object).
# The docstring below is the token regex read by PLY.
def t_T(t):
    r'(?i)t(oday)?'
    t.value = date.today()
    return t

# N token: 'now' / 'n'; the value is the current datetime.
# The docstring below is the token regex read by PLY.
def t_N(t):
    r'(?i)n(ow)?'
    # The first value is cached in the module-level `n`, so every later
    # occurrence of `now` yields the same instant.
    global n
    if not n:
        n = datetime.now()
    t.value = n
    return t

def get_closest_week_day(week_day):
    """Return the date of the occurrence of `week_day` nearest to today.

    `week_day` is a weekday name, full or abbreviated to three letters, in
    any case ("Monday", "mon"). Only the first three letters are compared, so
    abbreviations accepted by the lexer resolve correctly (previously they
    never matched and a date seven days back was returned silently).

    Today itself is never a candidate: the search looks 1..7 days ahead and
    1..7 days back. The future date wins only when it is strictly closer;
    on a tie the past date is returned.
    """
    wanted = week_day.lower()[:3]

    def search(step):
        # Walk one day at a time in direction `step`; report how many days
        # were walked and where the walk stopped (at most 7 days away).
        candidate = datetime.now()
        for walked in range(1, 8):
            candidate += timedelta(days=step)
            if days[candidate.weekday()][:3].lower() == wanted:
                break
        return walked, candidate

    ahead_count, ahead_date = search(1)
    back_count, back_date = search(-1)
    if ahead_count < back_count:
        return ahead_date.date()
    return back_date.date()

# Calendar-day part: three numeric fields separated by one non-digit, non-colon
# character each, or "<year> <month name> <day>".
REGEX_DOY = r'(?:\d+[^\d:]\d+[^\d:]\d+|\d+\s+(?:' + '|'.join(months) + '|' + '|'.join(months_abbrev) + r')\s+\d+)'
REGEX_0_23 = r'(?:2[0-3]|1?[0-9])'
REGEX_0_59 = r'(?:[1-5]?[0-9])'
REGEX_24_59 = r'(?:[3-5][0-9]|2[4-9])'

# One or two repetitions of: a calendar day (group 1), or a clock value
# "a:b" / "a:b:c" (groups 2-4) whose first two fields may carry a unit letter.
DATETIME_REGEX = rf'(?:({REGEX_DOY})\s?|(' + REGEX_0_23 + '[hHM]?):(' + REGEX_0_59 + f'[MsS]?)(?::({REGEX_0_59}))?)' + '{1,2}'
# Lexer rule for DATETIME. Depending on what was written, the token value is a
# date (numeric day only), a datetime (month-name day, or day plus clock) or a
# datetime.time (clock only).
@TOKEN(DATETIME_REGEX)
def t_DATETIME(t):
    date_str, *time_parts = re.search(DATETIME_REGEX, t.value).groups()
    # "a:b" with no unit letter anywhere is ambiguous (hours:minutes or
    # minutes:seconds), so the user is asked to disambiguate.
    # NOTE(review): this also rejects "<numeric date> HH:MM" -- confirm intent.
    if t.value.count(':') == 1 and \
            not any(c in t.value for c in 'hHMsS'):
        raise Exception(f'did you mean {time_parts[0]}h:{time_parts[1]} or {time_parts[0]}M:{time_parts[1]}?')
    if date_str:
        lowered = date_str.lower()
        if any(month.lower() in lowered for month in months_abbrev):
            # split() rather than split(' '): the regex allows any run of
            # whitespace between the three fields.
            y, b, d = date_str.split()
            try:
                day = datetime.strptime(f'{y.zfill(4)}-{b}-{d}', '%Y-%b-%d')
            except ValueError:
                try:
                    day = datetime.strptime(f'{y.zfill(4)}-{b}-{d}', '%Y-%B-%d')
                except ValueError:
                    raise Exception(f'Invalid syntax: {date_str}')
        else:
            y, m, d = replace(r'\D', '-', date_str).split('-')
            day = datetime.strptime(f'{y.zfill(4)}-{m}-{d}', '%Y-%m-%d').date()
        if all(part is None for part in time_parts):
            t.value = day
            return t
    else:
        day = datetime.today()

    first, second, third = time_parts
    is_HMS = t.value.count(':') == 2
    is_HM = 'h' in first or \
            'H' in first or \
            'M' in second or \
            (third is not None and \
                ('s' in third or \
                'S' in third))

    def strip_unit(part):
        # Drop any unit letter. The class lists 'H' explicitly: the former
        # '[hhMsS]' left an upper-case 'H' in place and strptime then failed.
        return replace(r'[hHMsS]', '', part)

    if is_HMS:
        H, M, S = [strip_unit(part) for part in time_parts]
        le_time = datetime.strptime(f'{H}:{M}:{S}', '%H:%M:%S').time()
    elif is_HM:
        H, M = [strip_unit(part) for part in time_parts if part]
        le_time = datetime.strptime(f'{H}:{M}', '%H:%M').time()
    else:
        M, S = [strip_unit(part) for part in time_parts if part]
        le_time = datetime.strptime(f'{M}:{S}', '%M:%S').time()

    if not date_str:
        t.value = le_time
    else:
        t.value = datetime.combine(day, le_time)
    return t

def parse_units(units_vals):
    """Convert a mapping of unit letter -> amount into a single timedelta.

    Units are looked up through `unit_map`. Years and months have no fixed
    length, so they are measured as the distance from "now" to "now plus N
    years/months" using relativedelta; every other unit maps directly onto a
    timedelta keyword. Keys not present in `unit_map` are ignored.
    """
    parsed = timedelta()
    # Take the reference instant once: calling datetime.now() on both sides of
    # the subtraction let the clock advance in between and leaked a few
    # microseconds of drift into the result.
    reference = datetime.now()
    for unit, long_name in unit_map.items():
        if unit not in units_vals:
            continue
        amount = units_vals[unit]
        if unit.lower() == 'y':
            parsed += (reference + relativedelta(years=amount)) - reference
        elif unit == 'm':
            parsed += (reference + relativedelta(months=amount)) - reference
        else:
            parsed += timedelta(**{long_name: amount})
    return parsed

# TIMESTAMP token: a bare run of digits, read as a unix timestamp in seconds
# and converted with datetime.fromtimestamp (local time).
# The docstring below is the token regex read by PLY.
def t_TIMESTAMP(t):
    r'\d+'
    t.value = datetime.fromtimestamp(int(t.value))
    return t

# Characters skipped between tokens (spaces and tabs).
t_ignore = ' \t'
# '#' starts a comment that runs to the end of the line and is discarded.
t_ignore_COMMENT = r'\#.*'

# Newlines produce no token (nothing is returned); they only advance the
# lexer's line counter. The docstring below is the regex read by PLY.
def t_newline(t):
    r'\n+'
    t.lexer.lineno += t.value.count('\n')

# Lexer error hook: report the offending character and skip past it so that
# lexing continues with the next character.
def t_error(t):
    print(f'illegal character {t.value[0]!r}')
    t.lexer.skip(1)

# Build the lexer from the token rules defined above in this module.
import ply.lex as lex
lex.lex(debug=False)

def wait(t):
    """Sleep for a duration, or until a point in time.

    `t` may be a timedelta (sleep that long), a datetime (sleep until then)
    or a plain date (sleep until midnight at the start of that day). A
    target that is zero, negative or already in the past returns immediately.

    Raises:
        Exception: if `t` is none of the supported types.
    """
    if isinstance(t, timedelta):
        delta = t
    elif isinstance(t, datetime):
        delta = t - datetime.now()
    elif isinstance(t, date):
        # Plain dates (today / tomorrow / yesterday and numeric dates) are
        # time points too; datetime is tested first because it subclasses date.
        delta = datetime.combine(t, datetime.min.time()) - datetime.now()
    else:
        raise Exception('wait accepts a time point or time delta only')
    if delta > timedelta(0):
        time.sleep(delta.total_seconds())

def last_wd(t):
    """Return the date of the most recent past occurrence of weekday `t`.

    `t` is anything whose str() is a weekday name, full or abbreviated to
    three letters, in any case. Only the first three letters are compared so
    that abbreviations match. The search covers the previous 1..7 days (today
    is excluded); if nothing matches, the date seven days back is returned.
    """
    wanted = str(t).lower()[:3]
    candidate = datetime.now()
    for _ in range(7):
        candidate -= timedelta(days=1)
        if days[candidate.weekday()][:3].lower() == wanted:
            break
    return candidate.date()

def cyclic(t, direction):
    """Return the date of the nearest weekday `t` in the given direction.

    `t` is anything whose str() is a weekday name, full or abbreviated to
    three letters, in any case; only the first three letters are compared so
    that abbreviations match. `direction` is the step in days per iteration
    (+1 searches forward, -1 backward). Today is excluded; at most seven
    steps are taken and the last date visited is returned.
    """
    wanted = str(t).lower()[:3]
    candidate = datetime.now()
    for _ in range(7):
        candidate += timedelta(days=direction)
        if days[candidate.weekday()][:3].lower() == wanted:
            break
    return candidate.date()

def dow(t):
    """Return the English weekday name for `t`.

    For a date or datetime the weekday of that day is returned; for a
    timedelta, the weekday of "now + t".

    Raises:
        Exception: if `t` is not a date, datetime or timedelta.
    """
    if isinstance(t, date):  # datetime is a subclass of date
        return days[t.weekday()]
    if isinstance(t, timedelta):
        return days[(datetime.now() + t).weekday()]
    # The message previously lacked the space before the type name.
    raise Exception('can\'t get day of week of object of type ' + str(type(t)))

def is_000(obj):
    """Tell whether `obj` carries no time-of-day information.

    An exact datetime qualifies when its hour, minute and second are all
    zero; an exact date always qualifies; anything else does not.
    """
    if type(obj) is datetime:
        return obj.hour == obj.minute == obj.second == 0
    return type(obj) is date

class Weekday:
    """A weekday name exactly as typed by the user (full or abbreviated)."""

    def __init__(self, name):
        # The raw matched text; no normalization is applied here.
        self.name = name

    def __str__(self):
        return self.name

    def __repr__(self):
        return f'{type(self).__name__}({self.name!r})'

class Month:
    """A month name exactly as typed by the user (full or abbreviated)."""

    def __init__(self, name):
        # The raw matched text; no normalization is applied here.
        self.name = name

    def __str__(self):
        return self.name

    def __repr__(self):
        return f'{type(self).__name__}({self.name!r})'

# Symbol table shared by the grammar actions. It starts with the built-in
# attributes/functions (each taking the operand as its single argument, except
# `help`, which takes none); user assignments and the last printed result
# ('_') are stored in the same dict.
names = {
            'day'       : lambda t                           : t.day,
            'month'     : lambda t                           : t.month,
            'year'      : lambda t                           : t.year,
            'hour'      : lambda t                           : t.hour,
            'minute'    : lambda t                           : t.minute,
            'second'    : lambda t                           : t.second,

            'wait'      : lambda t                           : wait(t),
            'dow'       : lambda t                           : dow(t),
            'dayofweek' : lambda t                           : dow(t),
            'weekday'   : lambda t                           : dow(t),
            'help'      : lambda                             : print(HELP),
        }

# Operator precedence for the parser, listed from LOWEST to HIGHEST binding.
#
# * Comparisons bind loosest, so `a + b < c` groups as `(a + b) < c`.
# * `in` comes next, so `a - b in hours` converts the whole difference.
# * Binary + and - are left-associative.
# * Unary minus binds tightest, so `-2M+2s` is `(-2M) + 2s`; it was previously
#   listed below PLUS/MINUS, which made it parse as `-(2M + 2s)`.
precedence = (
    ('left',
        'GT', 'GE', 'LT', 'LE', 'EQ', 'NE',
        ),
    ('left',
        'IN',
        ),
    ('left',
        'PLUS',
        'MINUS',
        ),
    ('right',
        'UMINUS',
        ),
    )

# Grammar rule (the docstring is the production read by PLY): statements may
# be chained with ';'. No semantic action is attached to the sequence itself.
def p_statements(p):
    'statement : statement SEMICOLON statement'

# Grammar rule (the docstring is the production read by PLY): assigning to a
# weekday name is recognized only so that it can be rejected with a clear
# error message.
def p_statement_invalid_assignment(p):
    'statement : WEEKDAY EQUALS expression'
    raise Exception('can\'t assign expression to weekday')

# Grammar rule (the docstring is the production read by PLY): store the value
# of the expression in `names` under the given variable name.
def p_statement_assign(p):
    'statement : NAME EQUALS expression'
    # NOTE(review): this binds a function-local `n`; without a `global`
    # declaration the module-level cache used by the `now` token is not
    # touched. Possibly meant to reset that cache -- confirm.
    n = None
    if is_reserved(p[1]):
        raise Exception('can\'t use reserved keyword')
    names[p[1]] = p[3]

def normalize(t):
    """Collapse a datetime with no time-of-day into a plain date.

    Exact datetime instances for which is_000() holds are returned as their
    date(); every other value is handed back untouched.
    """
    if type(t) is not datetime:
        return t
    return t.date() if is_000(t) else t

# Grammar rule (the docstring is the production read by PLY): a bare
# expression is evaluated and its result printed.
def p_statement_expr(p):
    'statement : expression'
    # A lone weekday name stands for its closest occurrence to today.
    if type(p[1]) is Weekday:
        p[1] = get_closest_week_day(str(p[1]))
    if p[1] is not None:
        # Midnight datetimes are printed as plain dates; the un-normalized
        # value is what gets remembered as '_'.
        print(normalize(p[1]))
        names['_'] = p[1]

def p_expression_binop(p):
    '''expression : expression PLUS expression
                  | expression MINUS expression
                  '''
    # Grammar rule (the docstring above is the production read by PLY):
    # add or subtract two operands (deltas and/or points in time).
    #
    # Raises Exception when either operand evaluated to None.
    if p[1] is None or p[3] is None:
        raise Exception(f'in {p[2]} expr, p[1]={p[1]} and p[3]={p[3]}')

    left, right = p[1], p[3]

    # Resolve weekday names to concrete dates FIRST, so that they take part in
    # the date -> datetime promotion below. Previously they were resolved
    # afterwards, so `Monday + 1h` dropped the hour while `T + 1h` kept it.
    if type(left) == Weekday:
        left = get_closest_week_day(str(left))
    if type(right) == Weekday:
        right = get_closest_week_day(str(right))

    # A plain date combined with a timedelta or a datetime is promoted to a
    # datetime at midnight so that time-of-day arithmetic is not lost.
    midnight = datetime.min.time()
    if type(left) == date and type(right) in (timedelta, datetime):
        left = datetime.combine(left, midnight)
    if type(right) == date and type(left) in (timedelta, datetime):
        right = datetime.combine(right, midnight)

    if   p[2] == '+':
        p[0] = left + right
    elif p[2] == '-':
        p[0] = left - right

def p_expression_comparison(p):
    '''expression : expression GT expression
                  | expression LT expression
                  | expression GE expression
                  | expression LE expression
                  | expression EQ expression
                  | expression NE expression
                  '''
    # Grammar rule (the docstring above is the production read by PLY):
    # compare two operands and yield a boolean. When the operands cannot be
    # compared, the TypeError text is printed and the result is left unset.
    midnight = datetime.min.time()

    # Mixed date/datetime operands: lift the plain date to midnight.
    if type(p[1]) is date and type(p[3]) is datetime:
        p[1] = datetime.combine(p[1], midnight)
    elif type(p[3]) is date and type(p[1]) is datetime:
        p[3] = datetime.combine(p[3], midnight)

    operator_text = p[2]
    try:
        # Ordering comparisons work on the operands as they are now.
        if operator_text == '<':
            p[0] = p[1] < p[3]
        elif operator_text == '>':
            p[0] = p[1] > p[3]
        elif operator_text == '>=':
            p[0] = p[1] >= p[3]
        elif operator_text == '<=':
            p[0] = p[1] <= p[3]

        # Before testing (in)equality, midnight datetimes are reduced back to
        # plain dates on both sides.
        for slot in (1, 3):
            if type(p[slot]) is datetime and is_000(p[slot]):
                p[slot] = p[slot].date()

        if operator_text == '==':
            p[0] = p[1] == p[3]
        elif operator_text == '!=':
            p[0] = p[1] != p[3]
    except TypeError as e:
        print(str(e))

def p_expression_funcall(p):
    'expression : NAME expression'
    # Grammar rule (the docstring above is the production read by PLY):
    # apply the function stored under NAME (looked up lower-cased in `names`)
    # to the value of the expression.
    #
    # Raises Exception when the name is unknown or is not callable. Errors
    # raised by the function itself now propagate unchanged: the former bare
    # `except` relabelled every failure (even a KeyboardInterrupt during
    # `wait`) as "Undefined name".
    try:
        func = names[p[1].lower()]
    except KeyError:
        raise Exception("Undefined name '%s'" % p[1]) from None
    if not callable(func):
        raise Exception("'%s' is not a function" % p[1])
    p[0] = func(p[2])

def p_expression_cyclic_op(p):
    'expression : CYCLIC_OP WEEKDAY'
    # Grammar rule (the docstring above is the production read by PLY):
    # `next WEEKDAY` / `last WEEKDAY` yield the date of that weekday after /
    # before today.
    # The keyword is matched case-insensitively by the lexer, so it is
    # lower-cased here; otherwise "NEXT monday" would search backwards.
    p[0] = cyclic(p[2], 1 if p[1].lower() == 'next' else -1)

def p_expression_point_in_unit(p):
    'expression : expression IN UNIT'
    # Grammar rule (the docstring above is the production read by PLY):
    #   <point in time> in unix  -> integer unix timestamp (via time.mktime)
    #   <delta> in seconds|minutes|hours|days|weeks -> float amount
    #
    # The unit word is matched case-insensitively by the lexer, so it is
    # lower-cased once here; comparing the raw text made e.g. "in Hours"
    # silently produce no result.
    #
    # Raises Exception when a timedelta is converted to unix.
    # NOTE(review): 'months' and 'years' are accepted by the lexer but have no
    # conversion; the result is left unset for them, as before.
    unit = p[3].lower()

    if unit == 'unix':
        if type(p[1]) == timedelta:
            raise Exception('can\'t convert timedelta to unix timestamp')
        p[0] = int(time.mktime(p[1].timetuple()))
        return

    # Successive divisors applied to the total number of seconds.
    divisors = {
        'seconds': (),
        'minutes': (60,),
        'hours': (60, 60),
        'days': (60, 60, 24),
        'weeks': (60, 60, 24, 7),
    }
    if unit in divisors:
        amount = p[1].total_seconds()
        for divisor in divisors[unit]:
            amount = amount / divisor
        p[0] = amount

def p_expression_generic(p):
    '''expression : DELTA 
                  | point
                  '''
    # Grammar rule (the docstring above is the production read by PLY):
    # a delta literal or a point in time is an expression; its value is
    # passed through unchanged.
    p[0] = p[1]

def p_point(p):
    '''point : N
             | T
             | TIMESTAMP
             | DATETIME
             | TM
             | WEEKDAY
             | YD
             '''
    # Grammar rule (the docstring above is the production read by PLY):
    # every token that denotes a point in time; the token's value (already
    # converted by the lexer) is passed through unchanged.
    p[0] = p[1]

def p_expression_name(p):
    'expression : NAME'
    # Grammar rule (the docstring above is the production read by PLY):
    # a bare name evaluates to the value stored in `names`. A stored function
    # is called with no arguments (this is how `help` works). An unknown name
    # is reported on stdout and evaluates to 0.
    #
    # The former fallback `names[p[3]](p[1])` could never succeed: this
    # production has a single symbol, so p[3] always raised IndexError (a
    # LookupError) and control fell through to the "Undefined name" branch.
    try:
        value = names[p[1]]
    except KeyError:
        print("Undefined name '%s'" % p[1])
        p[0] = 0
        return
    if type(value) is types.LambdaType:
        value = value()
    p[0] = value

# Grammar rule (the docstring is the production read by PLY): `expr.name`
# applies the function stored under `name` in `names` to the value of expr.
# An unknown name surfaces as a KeyError from the dict lookup.
def p_expression_get_attribute(p):
    'expression : expression PERIOD NAME'
    p[0] = names[p[3]](p[1])

# Grammar rule (the docstring is the production read by PLY): parentheses only
# group; the value of the inner expression is passed through.
def p_expression_group(p):
    'expression : LPAREN expression RPAREN'
    p[0] = p[2]

# Grammar rule (the docstring is the production read by PLY): unary minus
# negates its operand; `%prec UMINUS` gives the rule the precedence of the
# fictitious UMINUS entry in `precedence`.
def p_expression_uminus(p):
    'expression : MINUS expression %prec UMINUS'
    p[0] = -p[2]


# Build the parser from the grammar rules (p_* functions) defined above.
import ply.yacc as yacc
yacc.yacc()

def interactive():
    """Run a read-eval-print loop on stdin until EOF or `exit`.

    Every line that is not a built-in command is handed to the parser.
    Errors raised while evaluating a line are printed and the session
    continues; previously the first error terminated the whole loop.
    """
    import cmd

    class CmdParse(cmd.Cmd):
        prompt = ''
        # Lines that were evaluated successfully, in input order.
        commands = []

        def default(self, line):
            # cmd.Cmd delivers end-of-file as the literal line 'EOF'.
            if line == 'EOF':
                # sys.exit rather than the `exit` helper, which only exists
                # when the `site` module has been loaded.
                sys.exit(0)
            try:
                yacc.parse(line)
            except Exception as error:
                print(error)
            else:
                self.commands.append(line)

        def do_help(self, line):
            print(HELP)

        def do_exit(self, line):
            # Returning a true value stops cmdloop().
            return True

    CmdParse().cmdloop()

# Entry point: with command-line arguments, evaluate them (joined by spaces)
# as a single input; without arguments, start the interactive loop.
if __name__ == '__main__':
    if len(sys.argv) > 1:
        yacc.parse(' '.join(sys.argv[1:]))
    else:
        interactive()
