001    package cnslab.cnsnetwork;
002    
003    import java.io.ByteArrayInputStream;
004    import java.io.FileOutputStream;
005    import java.io.PrintStream;
006    import java.util.Iterator;
007    import java.util.Stack;
008    import java.util.concurrent.BrokenBarrierException;
009    import java.util.concurrent.CyclicBarrier;
010    import org.w3c.dom.Document;
011    import org.w3c.dom.Node;
012    import org.w3c.dom.NodeList;
013    import org.w3c.dom.ls.DOMImplementationLS;
014    import org.w3c.dom.ls.LSInput;
015    import org.w3c.dom.ls.LSParser;
016    import org.w3c.dom.bootstrap.DOMImplementationRegistry;
017    
018    import jpvm.*;
019    
020    import cnslab.cnsmath.*;
021
022// TODO:  Refactor in the same way NetHost was refactored.
023
024    /***********************************************************************
025    * Same as NetHost, but builds an ANetwork and also tracks avalanche data.
026    * 
027    * @version
028    *   $Date: 2012-08-04 13:43:22 -0500 (Sat, 04 Aug 2012) $
029    *   $Rev: 104 $
030    *   $Author: croft $
031    * @author
032    *   Yi Dong
033    * @author
034    *   David Wallace Croft
035    ***********************************************************************/
036    public final class  ANetHost
037    ////////////////////////////////////////////////////////////////////////
038    ////////////////////////////////////////////////////////////////////////
039    {
040  
041    public static JpvmInfo  info;
042
043    public static void main(String args[])
044    ////////////////////////////////////////////////////////////////////////
045    {
046      try {
047    double  minDelay;
048    double backFire;
049    int seedInt;
050
051    info= new JpvmInfo();
052    // Enroll in the parallel virtual machine...
053    info.jpvm = new jpvmEnvironment();
054
055    info.myJpvmTaskId = info.jpvm.pvm_mytid();
056
057
058    // Get my parent's task id...
059//    info.parent = info.jpvm.pvm_parent(); //actually it is grand parent;
060
061    //recevive infomation about it's peers
062    jpvmMessage m = info.jpvm.pvm_recv(NetMessageTag.sendTids);
063
064    info.numTasks =m.buffer.upkint();
065
066    info.tids = new jpvmTaskId[info.numTasks];
067    info.endIndex = new int[info.numTasks];
068    m.buffer.unpack(info.tids,info.numTasks,1);
069    m.buffer.unpack(info.endIndex,info.numTasks,1);
070    seedInt = m.buffer.upkint();
071
072    int iter_int;
073    for(iter_int=0; iter_int< info.numTasks; iter_int++)
074    {
075      if(info.myJpvmTaskId.equals(info.tids[iter_int])){break;}
076    }
077
078    info.idIndex=iter_int;
079
080    info.parentJpvmTaskId = m.buffer.upktid(); //trialHost.
081
082
083    info.tids[info.idIndex]= info.jpvm.pvm_parent();  //self id should not be stored, change to grandpa's id
084    //NetHosts tid
085
086
087    //FileOutputStream outt = new FileOutputStream("log/"+info.parent.getHost()+"_preinfo"+iter_int+".txt");
088    //PrintStream p = new  PrintStream(outt);
089
090
091//    String out="Host id "+info.idIndex+"\n";
092
093    byte [] ba;
094    int baLength;
095
096    baLength = m.buffer.upkint();
097    ba = new byte[baLength];
098    m.buffer.unpack(ba,baLength,1);
099
100    DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();
101    DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS");
102
103    LSInput input = impl.createLSInput();
104    input.setByteStream(new ByteArrayInputStream(ba));
105    LSParser parser = impl.createLSParser(DOMImplementationLS.MODE_SYNCHRONOUS,null);
106
107
108    SimulatorParser pas = new SimulatorParser(new Seed(seedInt-info.idIndex), parser.parse(input));
109
110    pas.parseMapCells(info.idIndex);
111
112//    p.println("cell maped");
113//    System.out.println(pas.ls.base+ " "+pas.ls.neuron_end);
114  
115    pas.parseNeuronDef();
116
117    pas.parsePopNeurons(info.idIndex);
118//    p.println("cell poped");
119    pas.parseScaffold(info.idIndex);
120    pas.layerStructure.buildStructure(info.idIndex);
121    pas.parseConnection(info.idIndex);
122
123    pas.parseTarget(info);
124
125
126//    pas.ls.sortSynapses(info.idIndex); //sort the  synapses, not necessary unless tune the parameters
127//    p.println("connected");
128    pas.parseExp(info.idIndex);
129//    p.println("exp");
130                pas.findMinDelay();
131//    p.close();
132
133    CyclicBarrier barrier = new CyclicBarrier(2);
134
135    /*
136    //receiving neurons
137    m = info.jpvm.pvm_recv(NetMessageTag.sendNeurons);
138    int numNeurons;
139    numNeurons = m.buffer.upkint();
140    int base = m.buffer.upkint();
141    Neuron neurons[] = new Neuron[numNeurons+1]; // the last neuron is a sync neuron;
142    m.buffer.unpack(neurons,numNeurons,1);
143    */
144
145
146
147    /*
148    out=out+"\n";
149    for(int i=0; i< numNeurons; i++)
150    {
151      out=out+"Neurons are:"+neurons[i].toString();
152    }
153    */
154
155    Seed idum= new Seed(seedInt-info.idIndex);
156
157
158    if(pas.layerStructure.axons == null) throw new RuntimeException("no axon info");
159
160    // neurons,pvminfo,base index, mini delay, background freq, seed number
161    
162    final Network testNet = new ANetwork (
163      pas.getModelFactory ( ),
164      pas.getDiscreteEventQueue ( ),
165      pas.getModulatedSynapseSeq ( ),
166      pas.layerStructure.neurons,
167      pas.layerStructure.axons,
168      info,
169      pas.layerStructure.base,
170      pas.minDelay,
171      pas,
172      idum,
173      pas.experiment );
174    
175    testNet.initNet();
176
177//    pas.p = testNet.p;
178
179
180//    if(info.idIndex==0) pas.ls.connectFrom("O,0,0,E", testNet.p);
181//    if(info.idIndex==4) pas.ls.connectFrom("O,0,16,0,E", testNet.p);
182  //  pas.ls.connectFrom("T,27,9,L", testNet.p);
183
184    testNet.p.println("Host id "+info.idIndex+"\n");
185    testNet.p.println("base "+testNet.base+"\n");
186    testNet.p.println("idum "+idum.seed+"\n");
187    testNet.p.println("NumOfSyn "+pas.layerStructure.numSYN+"\n");
188
189
190//    testNet.p.flush();
191
192    Object lock = new Object();
193    Object synLock = new Object();
194//    ListenInput listen = new ListenInput(testNet,lock);
195//                PCommunicationThread p1 = new PCommunicationThread(testNet, lock);
196 //               PComputationThread p2 = new PComputationThread(testNet, lock);
197    
198               PRun p1 = new PRun (
199                 testNet.getDiscreteEventQueue ( ),
200                 testNet,
201                 lock,
202                 barrier );
203
204   //    pas.ls.cellmap=null;
205         System.gc();
206
207              Thread run = new Thread(p1);
208        run.setDaemon(true);
209        run.start();
210
211//Barrier Sync
212    jpvmBuffer buf = new jpvmBuffer();
213    buf.pack("NetHost "+info.jpvm.pvm_mytid().toString()+" is ready to go"); //send out ready info;
214    info.jpvm.pvm_send(buf, info.tids[info.idIndex] ,NetMessageTag.readySig);
215//    m = info.jpvm.pvm_recv(NetMessageTag.readySig);
216//Barrier Sync
217
218        while (!testNet.stop)
219        {
220          m = testNet.info.jpvm.pvm_recv(); //receive info from others
221          if(m.messageTag==NetMessageTag.trialDone)testNet.trialDone=false;
222//          synchronized(lock)
223          {
224//            testNet.p.println("message "+m.messageTag);
225            // lock.notify();
226            switch(m.messageTag)
227            {
228              case NetMessageTag.sendSpike:
229                synchronized (lock)
230                {
231
232                  int sourceID = m.buffer.upkint();
233                  int trialID = m.buffer.upkint();
234                  if(trialID==testNet.countTrial)
235                  {
236                    //                    testNet.p.println("received and processed");
237                    testNet.spikeState=true;
238                    (testNet.received[sourceID])++;
239                    for(int iter=0;iter<info.numTasks;iter++)
240                    {
241                      if(iter!=info.idIndex && testNet.received[iter]==0) testNet.spikeState=false;
242                    }
243                    SpikeBuffer sbuff = (SpikeBuffer)m.buffer.upkcnsobj();
244                    Iterator<NetMessage> iter = sbuff.buff.iterator();
245                    while(iter.hasNext())
246                    {
247                      NetMessage message = iter.next();
248//                      try {
249                        for(int ii =0; ii <testNet.axons.get(message.from).branches.length; ii++)
250                        { 
251                          // new input events
252                          
253                          testNet.getInputEventSlot ( ).offer (
254                            new AInputEvent (
255                              message.time
256                                + testNet.axons.get ( message.from )
257                                .branches [ ii ].delay,
258                              testNet.axons.get ( message.from )
259                                .branches [ ii ],
260                              message.from,
261                              ( ( ANetMessage ) message ).sourceId,
262                              ( ( ANetMessage ) message ).avalancheId ) );
263                        }
264//                      }
265//                      catch(Exception ex) {
266//                        throw new RuntimeException(ex.getMessage()+"\n from:"+message.from+" host id:"+info.idIndex+" axons"+testNet.axons.size());
267//                      }
268                    }
269                    lock.notify();
270                  }
271                  /*
272               else
273               {
274                  //                    testNet.p.println("received and ignored");
275               }
276               */
277                }
278                break;
279              case NetMessageTag.syncRoot: //if its a message about time
280                synchronized (lock)
281                {
282
283                  if(!testNet.trialDone)
284                  {
285
286                    testNet.rootTime=m.buffer.upkdouble();
287                    lock.notify();
288                  }
289                }
290                break;
291              case NetMessageTag.stopSig: //if its a message about time
292                testNet.p.println("get stop message root "+testNet.rootTime);
293                testNet.p.flush();
294                buf = new jpvmBuffer();
295                buf.pack(testNet.recorderData);
296                //buf.pack(info.idIndex);
297                buf.pack(((ANetwork) testNet).aData);
298                testNet.p.println("Size of couting "+ ((ANetwork) testNet).aData.avalancheCounter.size());
299                testNet.p.flush();
300
301                info.jpvm.pvm_send(buf, info.tids[info.idIndex], NetMessageTag.getBackData);
302                
303                testNet.clearQueues ( );
304
305                testNet.stop=true;
306                testNet.p.close();
307                synchronized (lock)
308                {
309                  lock.notify();
310                }
311                break;
312              case NetMessageTag.tempStopSig: //if its a message about time
313                buf = new jpvmBuffer();
314                buf.pack(testNet.recorderData);
315                buf.pack(info.idIndex);
316                info.jpvm.pvm_send(buf, info.tids[info.idIndex], NetMessageTag.getBackData);
317                testNet.recorderData.clear();
318                break;
319              case NetMessageTag.trialDone: // new trial begins
320                //                        testNet.p.println("begin nofiying");
321                //
322                synchronized (lock)
323                {
324                  testNet.trialId = m.buffer.upkint();
325                  testNet.subExpId = m.buffer.upkint();
326                  testNet.endOfTrial = testNet.experiment.subExp[testNet.subExpId].trialLength;
327                  Thread.yield();
328                  Thread.yield();
329                  barrier.await();
330//                  synchronized (synLock)
331//                  {
332//                    synLock.notify();
333          //          testNet.startSig=true;
334//                  }
335                  lock.notify();
336                }
337                break;
338              case NetMessageTag.changeConnection:
339                seedInt = m.buffer.upkint();
340                baLength = m.buffer.upkint();
341                ba = new byte[baLength];
342                m.buffer.unpack(ba,baLength,1);
343
344                // DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();
345                // DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS");
346                input = impl.createLSInput();
347                input.setByteStream(new ByteArrayInputStream(ba));
348                parser = impl.createLSParser(DOMImplementationLS.MODE_SYNCHRONOUS,null);
349
350                NodeList conns = pas.rootElement.getElementsByTagName("Connections");
351                testNet.p.println("connection num"+conns.getLength());
352                pas.rootElement.removeChild(conns.item(0));
353                Node dup = pas.document.importNode(parser.parse(input).getDocumentElement().getElementsByTagName("Connections").item(0) , true);
354                pas.rootElement.appendChild(dup);
355                double weight;
356                weight = pas.parseChangeConnection(info.idIndex); //change the connections;
357                //synchroniz
358                //                testNet.p.println("connection change done with weight "+weight);
359                testNet.seed = new Seed(seedInt-info.idIndex);
360                //                testNet.p.println("new seed"+testNet.idum.seed);
361
362                buf = new jpvmBuffer();
363                if(weight<0)
364                {
365                  buf.pack("NetHost "+info.jpvm.pvm_mytid().getHost()+" has been changed"); //send out ready info;
366                }
367                else
368                {
369                  buf.pack("badweight$"+weight); //send out ready info;
370                }
371
372                info.jpvm.pvm_send(buf, info.tids[info.idIndex] ,NetMessageTag.readySig);
373                //end of sync
374                if(weight<0)
375                {
376                  buf = new jpvmBuffer();
377                  info.jpvm.pvm_send(buf, info.parentJpvmTaskId, NetMessageTag.trialDone);
378                }
379
380                break;
381              case NetMessageTag.netHostNotify:
382                //                testNet.p.println("nethost notify");
383                buf = new jpvmBuffer();
384                buf.pack("NetHost "+info.jpvm.pvm_mytid().toString()+" is ready to go"); //send out ready info;
385                info.jpvm.pvm_send(buf, info.tids[info.idIndex] ,NetMessageTag.readySig);
386
387                //                 testNet.idum.seed = testNet.saveSeed; //restore saved seed for comparison;
388                buf = new jpvmBuffer();
389                info.jpvm.pvm_send(buf, info.parentJpvmTaskId, NetMessageTag.trialDone);
390                break;
391              case NetMessageTag.resetNetHost:
392                synchronized(lock)
393                {
394                  //break;
395                  testNet.trialDone=true;
396
397                  //added;
398                  testNet.spikeState=true;
399                  for(int iter=0;iter<testNet.info.numTasks;iter++)
400                  {       
401                    /*
402                    while(!testNet.received[iter].empty())
403                    {
404                      testNet.received[iter].pop();
405                    }
406                    */
407                    testNet.received[iter]=1; //leave mark here
408                  }
409                  //added over;
410
411                  //initilization
412                  //                testNet.rootTime=0.0;
413                  
414                  testNet.recorderData.clear();
415                  
416                  testNet.clearQueues ( );
417                  
418                  testNet.countTrial++;
419                  
420                  //                testNet.p.println("reset the trial now");
421
422                  testNet.recordBuff.buff.clear();
423                  testNet.intraRecBuffers.init();
424
425                  for(int i=0; i<info.numTasks;i++)
426                  {
427                    testNet.spikeBuffers[i].buff.clear();
428                  }
429                  //                testNet.fireQueue.insertItem( new FireEvent(testNet.neurons.length-1+testNet.base,testNet.endOfTrial-testNet.minDelay/2.0));
430
431                  //                for(int i=0; i<testNet.neurons.length-1; i++)
432                  //                {
433                  //                  if(testNet.neurons[i].isSensory())
434                  //                  {
435                  // fireQueue.insertItem( new FireEvent(i+base, neurons[i].updateFire() )); //sensory neuron send spikes to the other neurons
436                  //                  }
437                  //                  else
438                  //                  {
439                  //                    testNet.neurons[i].init(idum); //nonsensory neurons will be initiazed;
440                  //                  }
441                  //                }
442
443                  lock.notify();
444                  while( info.jpvm.pvm_probe())
445                  {
446                    m = testNet.info.jpvm.pvm_recv(); //clear buffer
447                  }      
448
449                  buf = new jpvmBuffer();
450                  buf.pack("NetHost "+info.jpvm.pvm_mytid().toString()+" is ready to go"); //send out ready info;
451                  info.jpvm.pvm_send(buf, info.tids[info.idIndex] ,NetMessageTag.readySig);
452
453                  //                 testNet.idum.seed = testNet.saveSeed; //restore saved seed for comparison;
454
455                  buf = new jpvmBuffer();
456                  info.jpvm.pvm_send(buf, info.parentJpvmTaskId, NetMessageTag.trialDone);
457                  //                testNet.p.println("reset done");
458                }
459                break;
460              case NetMessageTag.checkTime:
461//                testNet.p.println("check time now");
462//                testNet.p.flush();
463                buf = new jpvmBuffer();
464                buf.pack(m.buffer.upkint()); //send the availale Host id
465                buf.pack(testNet.subExpId);
466                buf.pack(p1.minTime);
467                buf.pack(testNet.trialId);
468                info.jpvm.pvm_send(buf,info.tids[info.idIndex],NetMessageTag.checkTime);
469                break;
470            }
471          }
472        }
473
474      info.jpvm.pvm_exit();
475
476
477      } 
478      catch (jpvmException jpe) {
479    System.out.println("Error - jpvm exception");
480    try {
481      FileOutputStream out = new FileOutputStream("log/"+info.jpvm.pvm_mytid().getHost()+"error.txt");
482      PrintStream p = new  PrintStream(out);
483      jpe.printStackTrace(p);
484      p.close();
485    }
486    catch(Exception ex) {
487    }
488      }
489      catch (Exception a)
490      {
491        try {
492          FileOutputStream out = new FileOutputStream("log/"+info.jpvm.pvm_mytid().getHost()+"error.txt");
493          PrintStream p = new  PrintStream(out);
494          a.printStackTrace(p);
495          p.close();
496        }
497        catch(Exception ex) {
498        }
499      }
500  }
501};