#!/usr/bin/env python
# encoding: utf-8

import sys, array
from itertools import count

# Input alphabet of the regular expressions.
# NOTE(review): not referenced by the code visible here -- parseregex and
# compiletree hard-code the letters "a" and "b"; confirm before relying on it.
ALPHABET = "ab"
# Global source of fresh NFA state numbers (1, 2, 3, ...). The machine
# constructors draw every new state from this one counter, so states from
# different sub-machines never clash when their Q sets and transition tables
# are merged.
states = count(1)

def hash_set(s):
    """Return a canonical string key naming the collection of states `s`.

    The elements are sorted, so equal sets always produce the same key
    regardless of iteration order. They are joined with a comma, so different
    sets cannot collide: without a separator {1, 23} and {12, 3} would both
    become "123".
    """
    return ",".join(map(str, sorted(s)))

def merge_dictionaries(d1, d2):
    """Merge two dictionaries whose values are lists.

    Returns a new dict holding every key of d1 and d2. When a key is in both,
    its list is d1's entries followed by d2's. Neither argument is modified,
    and every list in the result is a fresh copy, so later changes to the
    result (e.g. adding transitions) cannot leak back into d1 or d2.
    """
    # Copy each list, not just the dict: a shallow dict copy would keep
    # sharing d1's lists, and extending them would mutate the caller's data.
    merged = dict((key, list(values)) for key, values in d1.items())
    for key, values in d2.items():
        merged.setdefault(key, []).extend(values)
    return merged

class NFA:
    """A nondeterministic finite automaton.

    Attributes:
        Q:  set of state names.
        F:  set of accepting states.
        q0: the start state.
        d:  transition table mapping a state to a list of (letter, next_state)
            pairs; the letter "" marks an empty-string (epsilon) move.
        re: printable form of the expression the machine represents
            ("" by default).
    """
    def __init__(self):
        self.Q = set()
        self.F = set()
        self.q0 = None
        self.d = {}
        self.re = ""

    def copy(self):
        """Return an independent copy of this machine, of the same class.

        The state sets and every transition list are copied, so mutating the
        copy (e.g. via addtrans) never affects the original.
        """
        # self.__class__ (not type(self)) also works for this old-style class
        # under Python 2.
        clone = self.__class__()
        clone.Q = set(self.Q)
        clone.F = set(self.F)
        clone.q0 = self.q0
        clone.d = dict((state, list(moves)) for state, moves in self.d.items())
        clone.re = self.re
        return clone

    def addtrans(self, state, delta):
        """Add transition `delta`, a (letter, next_state) pair, out of `state`."""
        # Store a new list instead of appending in place: the existing list
        # may be shared with another machine (e.g. after a shallow copy of d).
        self.d[state] = self.d.get(state, []) + [delta]

    def __str__(self):
        """Return a multi-line, human-readable description of the machine."""
        rows = []
        for state, transitions in self.d.items():
            for letter, newstate in transitions:
                rows.append("    %6s: %1s -> %6s" % (state, letter, newstate))
        lines = [
            "    Q = {" + ", ".join(map(str, self.Q)) + "}",
            "    q0 = %s" % self.q0,
            "    F = {" + ", ".join(map(str, self.F)) + "}",
            "    d = ",
            "\n".join(rows),
        ]
        return "\n".join(lines)

    def toDFA(self):
        """Return an equivalent DFA built by the subset construction.

        Each DFA state is a set of NFA states closed under "" moves. It is
        named with hash_set until simplify() merges equivalent states and
        renumbers them. Letters with no move from a subset lead to the DFA's
        implicit trap state (no transition is recorded).
        """
        from collections import deque

        def find_transitions(r, a):
            """States reachable from r by one transition on letter a."""
            return [q for x, q in self.d.get(r, ()) if x == a]

        def null_closure(start):
            """States reachable from `start` via "" moves only, `start` included."""
            closure = set(start)
            pending = list(closure)
            while pending:
                for q in find_transitions(pending.pop(), ""):
                    if q not in closure:
                        closure.add(q)
                        pending.append(q)
            return closure

        dfa = DFA()
        start = null_closure([self.q0])
        dfa.q0 = hash_set(start)

        # Every letter labelling some non-empty transition.
        alphabet = set(letter for moves in self.d.values() for letter, _ in moves)
        alphabet.discard("")

        # `seen` holds every subset ever queued. Checking dfa.Q alone (which is
        # only filled when a subset is processed) would let the same subset be
        # queued twice and emit duplicate transitions.
        queue = deque([start])
        seen = set([dfa.q0])
        while queue:
            R = queue.popleft()
            name = hash_set(R)
            dfa.Q.add(name)
            if R.intersection(self.F):
                dfa.F.add(name)
            for a in alphabet:
                target = set()
                for r in R:
                    target |= null_closure(find_transitions(r, a))
                if not target:
                    continue  # no move on `a`: implicit trap state
                target_name = hash_set(target)
                if target_name not in seen:
                    seen.add(target_name)
                    queue.append(target)
                dfa.addtrans(name, (a, target_name))

        return dfa.simplify()

    def simplify(self):
        """Merge equivalent states, then renumber the states 1..n.

        Two states are merged when they agree on acceptance and have the same
        outgoing transitions. Such states accept the same words, so merging
        them preserves the language. A merge can make further states
        identical, so passes repeat until one merges nothing.

        Returns a new machine; self is not modified.
        """
        machine = self
        while True:
            # Group all states by signature. A dict groups non-adjacent states
            # too, unlike itertools.groupby on an unsorted list.
            groups = {}
            for q in machine.Q:
                signature = (q in machine.F, repr(sorted(machine.d.get(q, []))))
                groups.setdefault(signature, []).append(q)

            renamed = {}
            merged = False
            for members in groups.values():
                keeper = members[0]
                for q in members:
                    renamed[q] = keeper
                if len(members) > 1:
                    merged = True

            if not merged:
                return machine.renumber()
            machine = machine.rename_states(renamed)

    def renumber(self):
        """Return a copy whose states are renamed to consecutive integers from 1."""
        numbering = count(1)
        return self.rename_states(dict((q, next(numbering)) for q in self.Q))

    def rename_states(self, renamed):
        """Return a copy with every state q replaced by renamed[q].

        `renamed` must map every state used in Q, F, q0 and d (KeyError
        otherwise). Several states may map to the same name. Their transition
        lists then overwrite each other, which is only safe when those lists
        are identical after renaming (as simplify guarantees).
        """
        machine = self.copy()
        machine.Q = set(renamed[q] for q in self.Q)
        machine.F = set(renamed[q] for q in self.F)
        machine.q0 = renamed[self.q0]
        newdelta = {}
        for state, transitions in self.d.items():
            newdelta[renamed[state]] = [(a, renamed[qprime]) for a, qprime in transitions]
        machine.d = newdelta
        return machine

class DFA(NFA):
    """A deterministic finite automaton sharing NFA's representation.

    When running input, only the first transition matching a letter is used.
    """
    def accepts(self, inpt):
        """Return True when reading `inpt` from q0 ends in an accepting state.

        A letter with no matching transition leads to an implicit trap state
        (represented by None), which never accepts.
        """
        current = self.q0
        for symbol in inpt:
            moves = self.d.get(current, [])
            current = next((dest for letter, dest in moves if letter == symbol), None)
        return current is not None and current in self.F

    def toDFA(self):
        """Already deterministic: return the machine itself."""
        return self

def mempty(*args):
    """Build an NFA whose language is exactly {""} (the empty string).

    It has one state that is both the start and the only accepting state.
    Positional arguments are accepted and ignored.
    """
    machine = NFA()
    only = next(states)
    machine.q0 = only
    machine.Q.add(only)
    machine.F.add(only)
    machine.re = ""
    return machine

def mnone(*args):
    """Build an NFA for the null language (accepts nothing).

    It has a single, non-accepting start state and no transitions.
    Positional arguments are accepted and ignored.
    """
    machine = NFA()
    machine.q0 = only = next(states)
    machine.Q.add(only)
    machine.re = "NULL"
    return machine

def exact(letter):
    """Build an NFA accepting only the one-letter word `letter`.

    Start state --letter--> accepting state; nothing else.
    """
    machine = NFA()
    start = next(states)
    machine.q0 = start
    machine.Q.add(start)
    accept = next(states)
    machine.Q.add(accept)
    machine.F.add(accept)
    machine.addtrans(start, (letter, accept))
    machine.re = letter
    return machine

def union(nfa1, nfa2):
    """Build an NFA accepting the union of the two machines' languages.

    A fresh start state gets an empty-string ("") move to each input
    machine's start. The states, transitions and accepting states of both
    inputs are carried over.
    """
    result = NFA()
    start = next(states)
    result.q0 = start
    result.Q = nfa1.Q.union(nfa2.Q)
    result.Q.add(start)
    result.d = merge_dictionaries(nfa1.d, nfa2.d)
    for branch in (nfa1, nfa2):
        result.addtrans(start, ("", branch.q0))
    result.F = nfa1.F.union(nfa2.F)
    result.re = "(%s) U (%s)" % (nfa1.re, nfa2.re)
    return result

def star(nfa):
    """Return an NFA for the Kleene star of `nfa`'s language.

    A fresh start state is added; it is the only accepting state. It has an
    empty-string ("") move into `nfa`'s start, and every accepting state of
    `nfa` gets a "" move back to it. `nfa` itself is not modified.
    """
    n = NFA()
    n.q0 = next(states)
    n.Q = nfa.Q.copy()
    n.Q.add(n.q0)
    n.F.add(n.q0)
    # Copy each transition list as well: a shallow dict copy would share the
    # lists with `nfa`, and adding the loop-back moves below would then
    # mutate the input machine.
    n.d = dict((state, list(moves)) for state, moves in nfa.d.items())
    n.addtrans(n.q0, ("", nfa.q0))
    for f in nfa.F:
        n.addtrans(f, ("", n.q0))
    n.re = "(%s)*" % nfa.re
    return n

def concat(nfa1, nfa2):
    """Return an NFA for nfa1's language followed by nfa2's.

    Starts at nfa1's start. Every accepting state of nfa1 gets an
    empty-string ("") move to nfa2's start, and only nfa2's accepting states
    accept. Neither input machine is modified.
    """
    n = NFA()
    n.Q = nfa1.Q.union(nfa2.Q)
    n.F = nfa2.F.copy()
    n.q0 = nfa1.q0
    n.d = merge_dictionaries(nfa1.d, nfa2.d)
    for f in nfa1.F:
        # Rebind to a new list rather than appending in place: the list in
        # n.d may be the very list stored in nfa1.d.
        n.d[f] = n.d.get(f, []) + [("", nfa2.q0)]
    n.re = "(%s) o (%s)" % (nfa1.re, nfa2.re)
    return n

def tree2infix(tree):
    """Render an expression tree in infix form.

    A leaf is a string; an inner node is a list [op, operand, ...]. A node
    with two or more operands is joined with " op ". A node with a single
    operand is written postfix (e.g. "a*"). Sub-expressions longer than one
    character are parenthesised.
    """
    def paren(expr):
        # Single characters (and "") need no grouping.
        if len(expr) <= 1:
            return expr
        return '(' + expr + ')'

    # isinstance works on Python 2 and 3 (types.ListType is Python 2 only).
    if not isinstance(tree, list):
        return paren(tree)
    op = tree[0]
    if len(tree) > 2:
        return (" " + op + " ").join(paren(tree2infix(sub)) for sub in tree[1:])
    return paren(tree2infix(tree[1])) + op

def compiletree(tree, indent = 0):
    """Converts an expression tree into an NFA."""
    # if the expression is empty, return the empty string language
    if tree == '':
        print "Creating an exact dfa for the empty string."
        return mempty()
    
    # if this is a letter, return exactly it.
    if tree == "a" or tree == "b":
        retval = exact(tree)
        print "Created DFDA for '%s':\n%s\n" % (tree, retval)
        return retval
    
    op = tree[0]
    if op == 'u' or op == 'U':   f = union
    elif op == '*': f = star
    elif op == 'o' or op == 'O': f = concat
    else:           f = mnone
    
    # To compile this expression, we 
    # first compile each of the subexpressions
    # and then apply the current operation onto
    # these new machines.
    compound_nfa = f(*map(compiletree, tree[1:]))
    print "Created NFA for %s:\n%s\n" % (compound_nfa.re, compound_nfa)
    return compound_nfa

def parseregex(expr):
    """Parse a prefix-notation regular expression into an expression tree.

    Tokens:
        a, b   letters
        U, u   union (binary)
        o, O   concatenation (binary)
        *      star (unary)
    Spaces and non-ASCII characters are skipped. "U a b" parses to
    ['U', 'a', 'b']. Concatenation lists its operands reversed: "o a b"
    parses to ['o', 'b', 'a'].

    Returns a one-element list holding the tree ([''] for the empty string).
    Raises Exception for a missing operand, an invalid token, or input that
    does not reduce to exactly one expression.
    """
    def num_params(op):
        if op in 'uUoO': return 2
        elif op == '*': return 1
        else: return 0

    if expr == '':
        return [expr]

    operands = []
    # Scan right to left with an operand stack: in prefix notation an
    # operator's operands follow it, so they are already stacked when the
    # operator is reached.
    for index in range(len(expr) - 1, -1, -1):
        token = expr[index]
        if token == " ":
            continue
        if token == "a" or token == "b":
            operands.append(token)
        elif token in "uUoO*":
            arity = num_params(token)
            if len(operands) < arity:
                raise Exception("Missing an operand: %s at char %d (%s)." %
                                (token, index, ", ".join(map(str, operands))))
            # pop() yields the leftmost (first-written) operand first.
            branch = [operands.pop() for _ in range(arity)]
            # NOTE: to match the examples given (where o lists its first
            # operand second), concatenation reverses its operands. 'O' is
            # the same operator as 'o' everywhere else, so it is treated the
            # same way here.
            if token in 'oO':
                branch.reverse()
            operands.append([token] + branch)
        else:
            # Skip non-ASCII characters (e.g. bytes of a UTF-8 encoded symbol).
            if ord(token) >= 128:
                continue
            raise Exception("Invalid token: %s." % token)

    if not operands:
        raise Exception("No expression found in: %r." % (expr,))
    if len(operands) > 1:
        raise Exception("Too many operands: %s." % operands)
    return operands

# Handles the command line arguments
if __name__ == '__main__':
    print "Enter regex: ",
    regex = raw_input().replace('u','U')
    tree = parseregex(regex)
    if len(tree) > 1:
        raise Exception("More than one expression in regex: %s." % tree)
    
    print "Infix: %s" % tree2infix(tree[0])
    print "Output in output.txt."
    realstdout = sys.stdout
    sys.stdout = open("output.txt", "w")
    print "RE2DFA output (Stephen Roller)"
    print "Input: %s" % regex 
    print "Infix: %s" % tree2infix(tree[0])
    print
    nfa = compiletree(tree[0])
    print "Minimized DFA: "
    dfa = nfa.toDFA()
    print dfa
    
    sys.stdout.close()
    sys.stdout = realstdout
    
    print "Final DFA: "
    print dfa
    print
    while True:
        try:
            word = raw_input("(^c to exit) input?: ")
            if dfa.accepts(word):
                print "Accept"
            else:
                print "Reject"
        except (KeyboardInterrupt, EOFError):
            print 
            break
