001 package cnslab.cnsnetwork; 002 003 import java.io.ByteArrayInputStream; 004 import java.io.FileOutputStream; 005 import java.io.PrintStream; 006 import java.util.Iterator; 007 import java.util.Stack; 008 import java.util.concurrent.BrokenBarrierException; 009 import java.util.concurrent.CyclicBarrier; 010 import org.w3c.dom.Document; 011 import org.w3c.dom.Node; 012 import org.w3c.dom.NodeList; 013 import org.w3c.dom.ls.DOMImplementationLS; 014 import org.w3c.dom.ls.LSInput; 015 import org.w3c.dom.ls.LSParser; 016 import org.w3c.dom.bootstrap.DOMImplementationRegistry; 017 018 import jpvm.*; 019 020 import cnslab.cnsmath.*; 021 022// TODO: Refactor in the same way NetHost was refactored. 023 024 /*********************************************************************** 025 * Same as NetHost but deals with avalanche 026 * 027 * @version 028 * $Date: 2012-08-04 13:43:22 -0500 (Sat, 04 Aug 2012) $ 029 * $Rev: 104 $ 030 * $Author: croft $ 031 * @author 032 * Yi Dong 033 * @author 034 * David Wallace Croft 035 ***********************************************************************/ 036 public final class ANetHost 037 //////////////////////////////////////////////////////////////////////// 038 //////////////////////////////////////////////////////////////////////// 039 { 040 041 public static JpvmInfo info; 042 043 public static void main(String args[]) 044 //////////////////////////////////////////////////////////////////////// 045 { 046 try { 047 double minDelay; 048 double backFire; 049 int seedInt; 050 051 info= new JpvmInfo(); 052 // Enroll in the parallel virtual machine... 053 info.jpvm = new jpvmEnvironment(); 054 055 info.myJpvmTaskId = info.jpvm.pvm_mytid(); 056 057 058 // Get my parent's task id... 059// info.parent = info.jpvm.pvm_parent(); //actually it is grand parent; 060 061 //recevive infomation about it's peers 062 jpvmMessage m = info.jpvm.pvm_recv(NetMessageTag.sendTids); 063 064 info.numTasks =m.buffer.upkint(); 065 066 info.tids = new jpvmTaskId[info.numTasks]; 067 info.endIndex = new int[info.numTasks]; 068 m.buffer.unpack(info.tids,info.numTasks,1); 069 m.buffer.unpack(info.endIndex,info.numTasks,1); 070 seedInt = m.buffer.upkint(); 071 072 int iter_int; 073 for(iter_int=0; iter_int< info.numTasks; iter_int++) 074 { 075 if(info.myJpvmTaskId.equals(info.tids[iter_int])){break;} 076 } 077 078 info.idIndex=iter_int; 079 080 info.parentJpvmTaskId = m.buffer.upktid(); //trialHost. 081 082 083 info.tids[info.idIndex]= info.jpvm.pvm_parent(); //self id should not be stored, change to grandpa's id 084 //NetHosts tid 085 086 087 //FileOutputStream outt = new FileOutputStream("log/"+info.parent.getHost()+"_preinfo"+iter_int+".txt"); 088 //PrintStream p = new PrintStream(outt); 089 090 091// String out="Host id "+info.idIndex+"\n"; 092 093 byte [] ba; 094 int baLength; 095 096 baLength = m.buffer.upkint(); 097 ba = new byte[baLength]; 098 m.buffer.unpack(ba,baLength,1); 099 100 DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance(); 101 DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS"); 102 103 LSInput input = impl.createLSInput(); 104 input.setByteStream(new ByteArrayInputStream(ba)); 105 LSParser parser = impl.createLSParser(DOMImplementationLS.MODE_SYNCHRONOUS,null); 106 107 108 SimulatorParser pas = new SimulatorParser(new Seed(seedInt-info.idIndex), parser.parse(input)); 109 110 pas.parseMapCells(info.idIndex); 111 112// p.println("cell maped"); 113// System.out.println(pas.ls.base+ " "+pas.ls.neuron_end); 114 115 pas.parseNeuronDef(); 116 117 pas.parsePopNeurons(info.idIndex); 118// p.println("cell poped"); 119 pas.parseScaffold(info.idIndex); 120 pas.layerStructure.buildStructure(info.idIndex); 121 pas.parseConnection(info.idIndex); 122 123 pas.parseTarget(info); 124 125 126// pas.ls.sortSynapses(info.idIndex); //sort the synapses, not necessary unless tune the parameters 127// p.println("connected"); 128 pas.parseExp(info.idIndex); 129// p.println("exp"); 130 pas.findMinDelay(); 131// p.close(); 132 133 CyclicBarrier barrier = new CyclicBarrier(2); 134 135 /* 136 //receiving neurons 137 m = info.jpvm.pvm_recv(NetMessageTag.sendNeurons); 138 int numNeurons; 139 numNeurons = m.buffer.upkint(); 140 int base = m.buffer.upkint(); 141 Neuron neurons[] = new Neuron[numNeurons+1]; // the last neuron is a sync neuron; 142 m.buffer.unpack(neurons,numNeurons,1); 143 */ 144 145 146 147 /* 148 out=out+"\n"; 149 for(int i=0; i< numNeurons; i++) 150 { 151 out=out+"Neurons are:"+neurons[i].toString(); 152 } 153 */ 154 155 Seed idum= new Seed(seedInt-info.idIndex); 156 157 158 if(pas.layerStructure.axons == null) throw new RuntimeException("no axon info"); 159 160 // neurons,pvminfo,base index, mini delay, background freq, seed number 161 162 final Network testNet = new ANetwork ( 163 pas.getModelFactory ( ), 164 pas.getDiscreteEventQueue ( ), 165 pas.getModulatedSynapseSeq ( ), 166 pas.layerStructure.neurons, 167 pas.layerStructure.axons, 168 info, 169 pas.layerStructure.base, 170 pas.minDelay, 171 pas, 172 idum, 173 pas.experiment ); 174 175 testNet.initNet(); 176 177// pas.p = testNet.p; 178 179 180// if(info.idIndex==0) pas.ls.connectFrom("O,0,0,E", testNet.p); 181// if(info.idIndex==4) pas.ls.connectFrom("O,0,16,0,E", testNet.p); 182 // pas.ls.connectFrom("T,27,9,L", testNet.p); 183 184 testNet.p.println("Host id "+info.idIndex+"\n"); 185 testNet.p.println("base "+testNet.base+"\n"); 186 testNet.p.println("idum "+idum.seed+"\n"); 187 testNet.p.println("NumOfSyn "+pas.layerStructure.numSYN+"\n"); 188 189 190// testNet.p.flush(); 191 192 Object lock = new Object(); 193 Object synLock = new Object(); 194// ListenInput listen = new ListenInput(testNet,lock); 195// PCommunicationThread p1 = new PCommunicationThread(testNet, lock); 196 // PComputationThread p2 = new PComputationThread(testNet, lock); 197 198 PRun p1 = new PRun ( 199 testNet.getDiscreteEventQueue ( ), 200 testNet, 201 lock, 202 barrier ); 203 204 // pas.ls.cellmap=null; 205 System.gc(); 206 207 Thread run = new Thread(p1); 208 run.setDaemon(true); 209 run.start(); 210 211//Barrier Sync 212 jpvmBuffer buf = new jpvmBuffer(); 213 buf.pack("NetHost "+info.jpvm.pvm_mytid().toString()+" is ready to go"); //send out ready info; 214 info.jpvm.pvm_send(buf, info.tids[info.idIndex] ,NetMessageTag.readySig); 215// m = info.jpvm.pvm_recv(NetMessageTag.readySig); 216//Barrier Sync 217 218 while (!testNet.stop) 219 { 220 m = testNet.info.jpvm.pvm_recv(); //receive info from others 221 if(m.messageTag==NetMessageTag.trialDone)testNet.trialDone=false; 222// synchronized(lock) 223 { 224// testNet.p.println("message "+m.messageTag); 225 // lock.notify(); 226 switch(m.messageTag) 227 { 228 case NetMessageTag.sendSpike: 229 synchronized (lock) 230 { 231 232 int sourceID = m.buffer.upkint(); 233 int trialID = m.buffer.upkint(); 234 if(trialID==testNet.countTrial) 235 { 236 // testNet.p.println("received and processed"); 237 testNet.spikeState=true; 238 (testNet.received[sourceID])++; 239 for(int iter=0;iter<info.numTasks;iter++) 240 { 241 if(iter!=info.idIndex && testNet.received[iter]==0) testNet.spikeState=false; 242 } 243 SpikeBuffer sbuff = (SpikeBuffer)m.buffer.upkcnsobj(); 244 Iterator<NetMessage> iter = sbuff.buff.iterator(); 245 while(iter.hasNext()) 246 { 247 NetMessage message = iter.next(); 248// try { 249 for(int ii =0; ii <testNet.axons.get(message.from).branches.length; ii++) 250 { 251 // new input events 252 253 testNet.getInputEventSlot ( ).offer ( 254 new AInputEvent ( 255 message.time 256 + testNet.axons.get ( message.from ) 257 .branches [ ii ].delay, 258 testNet.axons.get ( message.from ) 259 .branches [ ii ], 260 message.from, 261 ( ( ANetMessage ) message ).sourceId, 262 ( ( ANetMessage ) message ).avalancheId ) ); 263 } 264// } 265// catch(Exception ex) { 266// throw new RuntimeException(ex.getMessage()+"\n from:"+message.from+" host id:"+info.idIndex+" axons"+testNet.axons.size()); 267// } 268 } 269 lock.notify(); 270 } 271 /* 272 else 273 { 274 // testNet.p.println("received and ignored"); 275 } 276 */ 277 } 278 break; 279 case NetMessageTag.syncRoot: //if its a message about time 280 synchronized (lock) 281 { 282 283 if(!testNet.trialDone) 284 { 285 286 testNet.rootTime=m.buffer.upkdouble(); 287 lock.notify(); 288 } 289 } 290 break; 291 case NetMessageTag.stopSig: //if its a message about time 292 testNet.p.println("get stop message root "+testNet.rootTime); 293 testNet.p.flush(); 294 buf = new jpvmBuffer(); 295 buf.pack(testNet.recorderData); 296 //buf.pack(info.idIndex); 297 buf.pack(((ANetwork) testNet).aData); 298 testNet.p.println("Size of couting "+ ((ANetwork) testNet).aData.avalancheCounter.size()); 299 testNet.p.flush(); 300 301 info.jpvm.pvm_send(buf, info.tids[info.idIndex], NetMessageTag.getBackData); 302 303 testNet.clearQueues ( ); 304 305 testNet.stop=true; 306 testNet.p.close(); 307 synchronized (lock) 308 { 309 lock.notify(); 310 } 311 break; 312 case NetMessageTag.tempStopSig: //if its a message about time 313 buf = new jpvmBuffer(); 314 buf.pack(testNet.recorderData); 315 buf.pack(info.idIndex); 316 info.jpvm.pvm_send(buf, info.tids[info.idIndex], NetMessageTag.getBackData); 317 testNet.recorderData.clear(); 318 break; 319 case NetMessageTag.trialDone: // new trial begins 320 // testNet.p.println("begin nofiying"); 321 // 322 synchronized (lock) 323 { 324 testNet.trialId = m.buffer.upkint(); 325 testNet.subExpId = m.buffer.upkint(); 326 testNet.endOfTrial = testNet.experiment.subExp[testNet.subExpId].trialLength; 327 Thread.yield(); 328 Thread.yield(); 329 barrier.await(); 330// synchronized (synLock) 331// { 332// synLock.notify(); 333 // testNet.startSig=true; 334// } 335 lock.notify(); 336 } 337 break; 338 case NetMessageTag.changeConnection: 339 seedInt = m.buffer.upkint(); 340 baLength = m.buffer.upkint(); 341 ba = new byte[baLength]; 342 m.buffer.unpack(ba,baLength,1); 343 344 // DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance(); 345 // DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS"); 346 input = impl.createLSInput(); 347 input.setByteStream(new ByteArrayInputStream(ba)); 348 parser = impl.createLSParser(DOMImplementationLS.MODE_SYNCHRONOUS,null); 349 350 NodeList conns = pas.rootElement.getElementsByTagName("Connections"); 351 testNet.p.println("connection num"+conns.getLength()); 352 pas.rootElement.removeChild(conns.item(0)); 353 Node dup = pas.document.importNode(parser.parse(input).getDocumentElement().getElementsByTagName("Connections").item(0) , true); 354 pas.rootElement.appendChild(dup); 355 double weight; 356 weight = pas.parseChangeConnection(info.idIndex); //change the connections; 357 //synchroniz 358 // testNet.p.println("connection change done with weight "+weight); 359 testNet.seed = new Seed(seedInt-info.idIndex); 360 // testNet.p.println("new seed"+testNet.idum.seed); 361 362 buf = new jpvmBuffer(); 363 if(weight<0) 364 { 365 buf.pack("NetHost "+info.jpvm.pvm_mytid().getHost()+" has been changed"); //send out ready info; 366 } 367 else 368 { 369 buf.pack("badweight$"+weight); //send out ready info; 370 } 371 372 info.jpvm.pvm_send(buf, info.tids[info.idIndex] ,NetMessageTag.readySig); 373 //end of sync 374 if(weight<0) 375 { 376 buf = new jpvmBuffer(); 377 info.jpvm.pvm_send(buf, info.parentJpvmTaskId, NetMessageTag.trialDone); 378 } 379 380 break; 381 case NetMessageTag.netHostNotify: 382 // testNet.p.println("nethost notify"); 383 buf = new jpvmBuffer(); 384 buf.pack("NetHost "+info.jpvm.pvm_mytid().toString()+" is ready to go"); //send out ready info; 385 info.jpvm.pvm_send(buf, info.tids[info.idIndex] ,NetMessageTag.readySig); 386 387 // testNet.idum.seed = testNet.saveSeed; //restore saved seed for comparison; 388 buf = new jpvmBuffer(); 389 info.jpvm.pvm_send(buf, info.parentJpvmTaskId, NetMessageTag.trialDone); 390 break; 391 case NetMessageTag.resetNetHost: 392 synchronized(lock) 393 { 394 //break; 395 testNet.trialDone=true; 396 397 //added; 398 testNet.spikeState=true; 399 for(int iter=0;iter<testNet.info.numTasks;iter++) 400 { 401 /* 402 while(!testNet.received[iter].empty()) 403 { 404 testNet.received[iter].pop(); 405 } 406 */ 407 testNet.received[iter]=1; //leave mark here 408 } 409 //added over; 410 411 //initilization 412 // testNet.rootTime=0.0; 413 414 testNet.recorderData.clear(); 415 416 testNet.clearQueues ( ); 417 418 testNet.countTrial++; 419 420 // testNet.p.println("reset the trial now"); 421 422 testNet.recordBuff.buff.clear(); 423 testNet.intraRecBuffers.init(); 424 425 for(int i=0; i<info.numTasks;i++) 426 { 427 testNet.spikeBuffers[i].buff.clear(); 428 } 429 // testNet.fireQueue.insertItem( new FireEvent(testNet.neurons.length-1+testNet.base,testNet.endOfTrial-testNet.minDelay/2.0)); 430 431 // for(int i=0; i<testNet.neurons.length-1; i++) 432 // { 433 // if(testNet.neurons[i].isSensory()) 434 // { 435 // fireQueue.insertItem( new FireEvent(i+base, neurons[i].updateFire() )); //sensory neuron send spikes to the other neurons 436 // } 437 // else 438 // { 439 // testNet.neurons[i].init(idum); //nonsensory neurons will be initiazed; 440 // } 441 // } 442 443 lock.notify(); 444 while( info.jpvm.pvm_probe()) 445 { 446 m = testNet.info.jpvm.pvm_recv(); //clear buffer 447 } 448 449 buf = new jpvmBuffer(); 450 buf.pack("NetHost "+info.jpvm.pvm_mytid().toString()+" is ready to go"); //send out ready info; 451 info.jpvm.pvm_send(buf, info.tids[info.idIndex] ,NetMessageTag.readySig); 452 453 // testNet.idum.seed = testNet.saveSeed; //restore saved seed for comparison; 454 455 buf = new jpvmBuffer(); 456 info.jpvm.pvm_send(buf, info.parentJpvmTaskId, NetMessageTag.trialDone); 457 // testNet.p.println("reset done"); 458 } 459 break; 460 case NetMessageTag.checkTime: 461// testNet.p.println("check time now"); 462// testNet.p.flush(); 463 buf = new jpvmBuffer(); 464 buf.pack(m.buffer.upkint()); //send the availale Host id 465 buf.pack(testNet.subExpId); 466 buf.pack(p1.minTime); 467 buf.pack(testNet.trialId); 468 info.jpvm.pvm_send(buf,info.tids[info.idIndex],NetMessageTag.checkTime); 469 break; 470 } 471 } 472 } 473 474 info.jpvm.pvm_exit(); 475 476 477 } 478 catch (jpvmException jpe) { 479 System.out.println("Error - jpvm exception"); 480 try { 481 FileOutputStream out = new FileOutputStream("log/"+info.jpvm.pvm_mytid().getHost()+"error.txt"); 482 PrintStream p = new PrintStream(out); 483 jpe.printStackTrace(p); 484 p.close(); 485 } 486 catch(Exception ex) { 487 } 488 } 489 catch (Exception a) 490 { 491 try { 492 FileOutputStream out = new FileOutputStream("log/"+info.jpvm.pvm_mytid().getHost()+"error.txt"); 493 PrintStream p = new PrintStream(out); 494 a.printStackTrace(p); 495 p.close(); 496 } 497 catch(Exception ex) { 498 } 499 } 500 } 501};