#!python

# Run BSP code using threads to emulate processors
#
# Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
# last revision: 2004-3-29
#

import code, getopt
import ihooks, imp, pdb, traceback
import threading
import copy, os, operator, sys, time, types
from StringIO import StringIO
import Scientific.BSP

debug = 0

#
# Forced termination exception
#
class ForcedTermination(Exception):

    """Raised inside a virtual processor to make it stop executing.

    The virtual machine's sync() raises this exception when it finds
    that the current run has already been declared finished.
    """

#
# The virtual parallel machine
#
class VirtualBSPMachine:

    """A BSP machine whose processors are emulated by threads.

    Every virtual processor executes the same code object in its own
    name space and its own thread.  Only one thread is allowed to run
    at a time: run() acts as a scheduler that hands control to one
    processor by storing its pid in self.active_thread, and then waits
    on self.condition until the processor gives control back by
    resetting self.active_thread to -1 (in sync() or on termination).

    Attributes set up by the constructor, one list entry per processor:
      name_spaces -- global name space used for executing the code
      mailbox     -- messages deposited during the current superstep
      runtime     -- accumulated wall-clock time in seconds
      output      -- StringIO buffer (only filled by buffered runs)
      exception   -- sys.exc_info() of the last uncaught exception
    """

    def __init__(self, nproc, show_runtime=0):
        """Create a machine with nproc virtual processors.

        If show_runtime is true, run() prints the accumulated runtime
        of every processor when it finishes.
        """
        pids = range(nproc)
        self.nproc = nproc
        # Use the constructor argument, not a module-level variable.
        self.show_runtime = show_runtime
        self.current_pid = 0
        self.debugger = None
        self.name_spaces = [{'__name__': '__main__'} for pid in pids]
        self.mailbox = [[] for pid in pids]
        self.runtime = [0. for pid in pids]
        self.output = [None for pid in pids]
        self.exception = [None for pid in pids]
        self.condition = threading.Condition()
        # pid of the thread that is allowed to run, -1 for the scheduler
        self.active_thread = -1

    def currentPid(self):
        """Return the pid of the processor that was scheduled last."""
        return self.current_pid

    def numberOfProcessors(self):
        """Return the number of virtual processors of this machine."""
        return self.nproc

    def isExtendedName(self, name):
        """Return true if name carries a prefix made by extendedName().

        The prefix is '_P', five digits and '_'.  Slices are used
        instead of indexing so that short names are simply rejected.
        """
        return name[:2] == '_P' and name[2:7].isdigit() \
               and name[7:8] == '_'

    def extendedName(self, name, pid = None):
        """Return name prefixed with the tag of processor pid.

        pid defaults to the processor that was scheduled last.
        """
        if pid is None:
            pid = self.current_pid
        return ("_P%05d_" % pid) + name

    def baseName(self, name):
        """Return name without its processor prefix, if it has one."""
        if self.isExtendedName(name):
            return name[8:]
        else:
            return name

    def virtualProcessor(self, pid):
        """Thread body of processor pid.

        Waits until the scheduler activates this pid, executes
        self.code in the processor's name space, and finally marks
        the processor as terminated and returns control.
        """
        if debug: print("*** Starting processor %s" % pid)
        self.condition.acquire()
        try:
            while self.active_thread != pid:
                self.condition.wait()
        finally:
            self.condition.release()
        self.exception[pid] = None
        if self.debugger is not None:
            if self.current_pid in self.debugged_pids:
                print("*** Debugging pid %s" % self.current_pid)
                self.debugger.reset()
                sys.settrace(self.debugger.trace_dispatch)
            else:
                print("*** Not debugging pid %s" % self.current_pid)
        try:
            try:
                exec(self.code, self.name_spaces[pid])
            except:
                # Deliberately catches everything: the exception is kept
                # for later inspection and the thread must still report
                # its termination to the scheduler below.
                self.exception[pid] = sys.exc_info()
                self.showtraceback()
        finally:
            # Always hand control back, otherwise run() would wait forever.
            if self.debugger is not None:
                sys.settrace(None)
            self.terminated[pid] = True
            self.condition.acquire()
            try:
                self.active_thread = -1
                self.condition.notifyAll()
            finally:
                self.condition.release()

    def run(self, code, buffered_output = 0):
        """Execute the code object on all virtual processors.

        One thread per processor is started.  The processors are then
        activated one after the other; each runs until it calls sync()
        or terminates.  One such round is a superstep, after which the
        mailboxes become the messages returned by sync().  The loop
        ends as soon as at least one processor has terminated.

        If buffered_output is true, sys.stdout and sys.stderr are
        redirected to a per-processor StringIO (self.output[pid])
        while that processor is running.
        """
        self.code = code
        pids = range(self.nproc)
        self.done = False
        self.terminated = [False for pid in pids]
        self.active_thread = -1
        for pid in pids:
            if buffered_output:
                self.output[pid] = StringIO()
            thread = threading.Thread(name = 'Processor %d' % pid,
                                      target = self.virtualProcessor,
                                      args = (pid,))
            thread.start()

        super_step = 0
        real_stdout = sys.stdout
        real_stderr = sys.stderr
        while 1:
            if debug: print("*** Superstep  %s" % super_step)
            for pid in pids:
                # sync() and the import machinery read self.current_pid
                self.current_pid = pid
                if debug: print("*** Running  %s" % pid)
                if buffered_output:
                    sys.stdout = self.output[pid]
                    sys.stderr = self.output[pid]
                try:
                    self.condition.acquire()
                    try:
                        self.active_thread = pid
                        start_time = time.time()
                        self.condition.notifyAll()
                        # Wait until the processor yields or terminates.
                        while self.active_thread != -1:
                            self.condition.wait()
                    finally:
                        self.condition.release()
                    sys.settrace(None)
                    self.runtime[pid] += time.time()-start_time
                finally:
                    # Never leave the real streams redirected.
                    if buffered_output:
                        sys.stdout = real_stdout
                        sys.stderr = real_stderr
                if debug: print("***  %s  at barrier" % pid)
            if debug: print("*** End of superstep  %s" % super_step)
            # Messages sent during this superstep become readable now.
            self.messages = self.mailbox
            self.mailbox = [[] for pid in pids]
            super_step += 1
            self.done = reduce(operator.or_, self.terminated)
            if self.done:
                # NOTE(review): processors that are still blocked in
                # sync() at this point are not resumed by this loop.
                if debug:
                    print("*** End of execution ")
                    for pid in pids:
                        if not self.terminated[pid]:
                            print("*** Processor %d still running" % pid)
                break
        if self.show_runtime:
            print("Runtimes:")
            for pid in pids:
                print("  Processor %d: %f s" % (pid, self.runtime[pid]))

    def showtraceback(self):
        """Write the traceback of the current exception to sys.stderr.

        The first traceback entry (the frame that executed the code)
        is dropped.  The exception is also stored in sys.last_type,
        sys.last_value and sys.last_traceback.
        """
        try:
            exc_type, exc_value, tb = sys.exc_info()
            sys.last_type = exc_type
            sys.last_value = exc_value
            sys.last_traceback = tb
            tblist = traceback.extract_tb(tb)
            del tblist[:1]
            lines = traceback.format_list(tblist)
            if lines:
                lines.insert(0, "Traceback (most recent call last):\n")
            lines.extend(traceback.format_exception_only(exc_type,
                                                         exc_value))
        finally:
            # Drop the references to the traceback objects.
            tblist = tb = None
        for line in lines:
            sys.stderr.write(line)

    def run_debug(self, command, debugger, pids):
        """Run command unbuffered with debugger tracing the given pids."""
        self.debugger = debugger
        self.debugged_pids = pids
        self.run(command, 0)
        self.debugger = None

    def put(self, object, pid_list):
        """Deposit a deep copy of object in the mailbox of each pid.

        The copies are returned by the receivers' sync() calls after
        the current superstep.  When a debugger is attached, tracing
        is switched off here; sync() switches it on again.
        """
        if self.debugger:
            sys.settrace(None)
        for pid in pid_list:
            self.mailbox[pid].append(copy.deepcopy(object))

    def send(self, messages):
        """Deposit each (pid, data) pair of messages via put()."""
        if self.debugger:
            sys.settrace(None)
        for pid, data in messages:
            self.put(data, [pid])

    def sync(self):
        """Barrier: give up control and wait for the next superstep.

        Returns the list of messages that were sent to the calling
        processor during the superstep that just ended.  Raises
        ForcedTermination if the run has been declared finished by
        the time the processor is activated again.
        """
        if self.debugger:
            sys.settrace(None)
        pid = self.current_pid
        if debug: print("---  %s  at sync" % pid)
        self.condition.acquire()
        try:
            # Hand control to the scheduler, then wait for our next turn.
            self.active_thread = -1
            self.condition.notifyAll()
            while self.active_thread != pid:
                self.condition.wait()
        finally:
            self.condition.release()
        if self.done:
            if debug: print("---  %s  terminated" % pid)
            raise ForcedTermination
        if debug: print("---  %s  continuing after sync" % pid)
        if self.debugger is not None:
            if self.current_pid in self.debugged_pids:
                print("*** Debugging pid %s" % self.current_pid)
                self.debugger.reset()
                sys.settrace(self.debugger.trace_dispatch)
            else:
                print("*** Not debugging pid %s" % self.current_pid)
        return self.messages[pid]

#
# Change module imports such that every virtual processor gets its
# own copy of all modules except for builtin modules.
#
class Hooks(ihooks.Hooks):

    """Import hooks that ignore the per-processor prefix of module names.

    Every query about builtin or frozen modules is forwarded to the
    imp module after the name has been reduced to its base name by
    the virtual machine.
    """

    def is_builtin(self, name):
        plain_name = machine.baseName(name)
        return imp.is_builtin(plain_name)

    def init_builtin(self, name):
        plain_name = machine.baseName(name)
        return imp.init_builtin(plain_name)

    def is_frozen(self, name):
        plain_name = machine.baseName(name)
        return imp.is_frozen(plain_name)

    def init_frozen(self, name):
        plain_name = machine.baseName(name)
        return imp.init_frozen(plain_name)

    def get_frozen_object(self, name):
        plain_name = machine.baseName(name)
        return imp.get_frozen_object(plain_name)

class ModuleLoader(ihooks.ModuleLoader):

    """Module loader that searches directories by base module name."""

    def find_module_in_dir(self, name, dir, allow_packages=1):
        # A real directory is searched for the name without its
        # processor prefix; without a directory the (still prefixed)
        # name is looked up among the builtin modules.
        if dir is not None:
            return ihooks.ModuleLoader.find_module_in_dir(
                self, machine.baseName(name), dir, allow_packages)
        return self.find_builtin_module(name)

class ModuleImporter(ihooks.ModuleImporter):

    """Importer that registers modules under per-processor names."""

    def __init__(self, loader = None):
        ihooks.ModuleImporter.__init__(self, loader)
        # Make every module that is already loaded, except the
        # Scientific package and names that are already extended,
        # available under the extended name of each processor.
        for name in self.modules.keys():
            if machine.isExtendedName(name):
                continue
            if name == "Scientific" or name[:11] == "Scientific.":
                continue
            module = self.modules[name]
            for pid in range(machine.nproc):
                self.modules[machine.extendedName(name, pid)] = module

    def import_module(self, name, globals=None, locals=None, fromlist=None):
        # Import under the extended name of the current processor.
        extended = machine.extendedName(name)
        return ihooks.ModuleImporter.import_module(self, extended,
                                                   globals, locals, fromlist)

    def import_it(self, partname, fqname, parent, force_load=0):
        if not partname:
            raise ValueError("Empty module name")
        if not force_load and fqname in self.modules:
            return self.modules[fqname]
        if parent:
            try:
                search_path = parent.__path__
            except AttributeError:
                # The parent is not a package.
                return None
        else:
            search_path = parent
        found = self.loader.find_module(partname, search_path)
        if not found:
            return None
        module = self.loader.load_module(fqname, found)
        if parent:
            # The attribute of the parent carries the base name.
            setattr(parent, machine.baseName(partname), module)
        return module

    def find_head_package(self, parent, name):
        if '.' in name:
            head, tail = name.split('.', 1)
        else:
            head, tail = name, ""
        if parent:
            qname = "%s.%s" % (parent.__name__, machine.baseName(head))
        else:
            qname = head
        package = self.import_it(head, qname, parent)
        if package:
            return package, tail
        if parent:
            # Second attempt: look for a top-level module.
            qname = head
            parent = None
            package = self.import_it(head, qname, parent)
            if package:
                return package, tail
        raise ImportError("No module named " + qname)

#
# An interactive parallel console
#
class VirtualConsole(code.InteractiveConsole):

    """Interactive console that runs every command on all processors.

    Lines starting with '!' are not sent to the virtual machine.
    They are executed directly in a name space that contains
      pm(pid)            -- post-mortem debugging of processor pid
      run(pids, command) -- run command with pdb tracing the given pids
    """

    def __init__(self, virtual_machine):
        code.InteractiveConsole.__init__(self)
        self.virtual_machine = virtual_machine

    def runcode(self, compiled_code):
        """Run compiled_code with buffered output and show the output."""
        try:
            self.virtual_machine.run(compiled_code, 1)
            self.showOutput()
        except SystemExit:
            raise
        except:
            self.showOutput()
            self.virtual_machine.showtraceback()
        else:
            if code.softspace(sys.stdout, 0):
                print("")

    def showOutput(self):
        """Write the buffered output of each processor under a header."""
        for pid in range(self.virtual_machine.numberOfProcessors()):
            buffer = self.virtual_machine.output[pid]
            if buffer is None:
                # No buffered run has created an output buffer yet.
                continue
            output = buffer.getvalue()
            if output:
                sys.stdout.write(("-- Processor %d " % pid) + 40*'-' + '\n')
                sys.stdout.write(output)

    def push(self, line):
        """Handle one input line; returns true if more input is needed.

        A failing '!' command is reported like any other error instead
        of terminating the interactive session.
        """
        if line and line[0] == '!':
            command = line.strip()[1:]
            try:
                exec(command, {'pm': self.pdb_pm, 'run': self.pdb_run})
            except SystemExit:
                raise
            except:
                self.showtraceback()
            return 0
        return code.InteractiveConsole.push(self, line)

    def pdb_pm(self, pid):
        """Start post-mortem debugging for the last error on pid."""
        exc_info = self.virtual_machine.exception[pid]
        if exc_info is None:
            self.write("No exception recorded for processor %d\n" % pid)
            return
        # The last item of the sys.exc_info() tuple is the traceback.
        pdb.post_mortem(exc_info[-1])

    def pdb_run(self, pids, command):
        """Run command (a source string) under pdb on the given pids.

        pids is a single pid or a sequence of pids.
        """
        try:
            command = compile(command, '<string>', 'exec')
        except SyntaxError:
            self.showsyntaxerror()
            return
        if isinstance(pids, int):
            pids = [pids]
        debugger = pdb.Pdb()
        self.virtual_machine.run_debug(command, debugger, pids)
        debugger.quitting = 1
        sys.settrace(None)
        
# Command line:  [-i] [-t] nproc [script]
#   -i  enter the interactive console (implied if no script is given)
#   -t  print the runtime of each processor after a run
options, args = getopt.getopt(sys.argv[1:], 'it')
interactive = 0
runtime = 0
for option, value in options:
    if option == '-i':
        interactive = 1
    if option == '-t':
        runtime = 1
if len(args) < 2:
    interactive = 1

# The number of processors is mandatory and must be a positive integer.
try:
    nproc = int(args[0])
    if nproc < 1:
        raise ValueError(args[0])
except (IndexError, ValueError):
    sys.stderr.write("Usage: %s [-i] [-t] nproc [script]\n"
                     % os.path.basename(sys.argv[0]))
    sys.stderr.write("nproc must be a positive integer\n")
    sys.exit(2)

machine = VirtualBSPMachine(nproc, runtime)
sys.virtual_bsp_machine = machine

# Install the import machinery that gives each processor its own modules.
hooks = Hooks()
loader = ModuleLoader(hooks)
importer = ModuleImporter(loader)
sys.setrecursionlimit((nproc+1)*1000)
importer.install()

cprt = 'Type "copyright", "credits" or "license" for more information.'
banner = "Python %s on %s\n%s\n(Parallel console, %d processors)" % \
         (sys.version, sys.platform, cprt, nproc)

if len(args) > 1:
    script_name = args[1]
    script_file = open(script_name)
    try:
        script = script_file.read()
    finally:
        script_file.close()
    compiled_code = compile(script, script_name, 'exec')
    machine.run(compiled_code)
    # No banner when the console is entered after running a script.
    banner = ''

if interactive:
    console = VirtualConsole(machine)
    console.interact(banner)

