001 package cnslab.cnsnetwork; 002 003 import org.w3c.dom.Document; 004 import org.w3c.dom.Element; 005 import org.w3c.dom.NodeList; 006 import cnslab.cnsmath.Seed; 007 import org.w3c.dom.Node; 008 import org.w3c.dom.bootstrap.DOMImplementationRegistry; 009 import org.w3c.dom.ls.DOMImplementationLS; 010 import org.w3c.dom.ls.LSSerializer; 011 import org.w3c.dom.ls.LSOutput; 012 import java.io.FileOutputStream; 013 import java.io.File; 014 import java.io.ByteArrayOutputStream; 015 import java.io.ByteArrayInputStream; 016 import java.io.InputStream; 017 import org.w3c.dom.DocumentType; 018 import jpvm.jpvmTaskId; 019 import cnslab.cnsnetwork.JpvmInfo; 020 import java.util.LinkedList; 021 import java.util.HashMap; 022 import java.util.Set; 023 import java.util.Map; 024 import jpvm.jpvmEnvironment; 025 import jpvm.jpvmBuffer; 026 import jpvm.jpvmMessage; 027 import jpvm.jpvmException; 028 import java.util.Iterator; 029 import java.util.Timer; 030 import java.util.TimerTask; 031 032 /*********************************************************************** 033 * Deprecated. 034 * 035 * @version 036 * $Date: 2012-08-04 13:43:22 -0500 (Sat, 04 Aug 2012) $ 037 * $Rev: 104 $ 038 * $Author: croft $ 039 * @author 040 * Yi Dong 041 * @author 042 * David Wallace Croft 043 ***********************************************************************/ 044@Deprecated 045public class FitFun 046{ 047 public FitFun(ParToDoc parDoc, Seed idum, String modelFilename, int heapSize ) 048 { 049 try { 050 this.parDoc = parDoc; 051 this.seedInt = idum.seed; 052 this.saveInt = idum.seed; 053 this.idum = idum; 054 055 //spawn all the slaves and ready for the simulation 056 pas = new SimulatorParser ( 057 idum, new File ( "model/" + modelFilename ) ); 058 059 pas.parseMapCells(); 060 pas.parseExperiment(); 061 exp = pas.experiment; // experiment infomation 062 pas.findMinDelay(); 063 064 pas.document.removeChild(pas.documentType); 065 DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance(); 066 DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS"); 067 LSSerializer writer = impl.createLSSerializer(); 068 LSOutput output = impl.createLSOutput(); 069 ByteArrayOutputStream bArray = new ByteArrayOutputStream(); 070 output.setByteStream(bArray); 071 writer.write(pas.document, output); 072 pas.document.appendChild(pas.documentType); 073 byte [] ba = bArray.toByteArray(); 074 075 076 info= new JpvmInfo(); 077 // Enroll in the parallel virtual machine... 078 info.jpvm = new jpvmEnvironment(); 079 080 // Get my task id... 081 info.myJpvmTaskId = info.jpvm.pvm_mytid(); 082 System.out.println("Task Id: "+info.myJpvmTaskId.toString()); 083 084 info.numTasks= pas.parallelHost; // total number of trial hosts; 085 info.idIndex=info.numTasks; // root id; //not used at all 086 info.tids = new jpvmTaskId[info.numTasks]; 087 088 // Spawn some trialHosts 089 info.jpvm.pvm_spawn("cnslab.cnsnetwork.TrialHost",info.numTasks,info.tids,48); 090 System.out.println("spawn successfully"); 091 092 jpvmBuffer buf2 = new jpvmBuffer(); 093 094 buf2.pack(info.numTasks); 095 buf2.pack(info.tids, info.numTasks, 1); 096 buf2.pack(pas.minDelay); 097 098 info.jpvm.pvm_mcast(buf2,info.tids,info.numTasks,NetMessageTag.sendTids); 099 100 System.out.println("All sent"); 101 102 //gernerate all the nethosts : 103 info.endIndex=pas.layerStructure.nodeEndIndices; 104 105 netTids = new jpvmTaskId [info.numTasks][info.endIndex.length] ; 106 107 for(int i = 0 ; i < info.numTasks; i++) 108 { 109 System.out.println("generate child for trialHost "+i); 110 info.jpvm.pvm_spawn("cnslab.cnsnetwork.NetHostTune",info.endIndex.length,netTids[i],heapSize); //Net Host is to seperate large network into small pieces; 111 jpvmBuffer buf = new jpvmBuffer(); 112 buf.pack(info.endIndex.length); 113 buf.pack(netTids[i],info.endIndex.length,1); 114 buf.pack(info.endIndex,info.endIndex.length,1); 115 seedInt = seedInt - info.endIndex.length; 116 buf.pack(seedInt); 117 buf.pack(info.tids[i]); //parent's tid; 118 buf.pack(ba.length); 119 buf.pack(ba, ba.length, 1); 120 info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.sendTids); 121 } 122 123 for(int i = 0 ; i < info.numTasks; i++) 124 { 125 jpvmBuffer buf = new jpvmBuffer(); 126 buf.pack(info.endIndex.length); 127 buf.pack(netTids[i],info.endIndex.length,1); 128 info.jpvm.pvm_send(buf,info.tids[i],NetMessageTag.sendTids2); //send trial Host the child tids 129 } 130 131 132 //Barrier Sync 133 for (int i=0;i<info.numTasks*info.endIndex.length+info.numTasks; i++) { 134 // Receive a message... 135 jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig); 136 // Unpack the message... 137 String str = message.buffer.upkstr(); 138 System.out.println(str); 139 } 140 jpvmBuffer buf = new jpvmBuffer(); 141 142 } 143 catch(jpvmException ex) { 144 ex.printStackTrace(); 145 }catch(Exception ap) 146 { 147 ap.printStackTrace(); 148 } 149 } 150 151 public Seed idum; 152 153 public double valueSd; // value's standard deviation 154 155 public int trialId = -1; 156 157 public int expId = 0; 158 159 public int saveInt; 160 161 162 public ParToDoc parDoc; 163 164 public Experiment exp; 165 166 public JpvmInfo info; 167 public int [] endIndex; 168 169 public double minDelay; 170 public double backFire; 171 public int seedInt; 172 173 public jpvmTaskId[][] netTids; 174 175 public LinkedList[] intraReceiver ; 176 public RecorderData rdata; 177 178 179 public SimulatorParser pas; 180 181 public void closeHosts() { 182 try { 183 184 for (int i=0;i<info.numTasks; i++) { 185 jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.trialReady); 186 int freeId = message.buffer.upkint(); 187 System.out.println("Trial Host "+freeId+" is killed"); 188 jpvmBuffer buf = new jpvmBuffer(); 189 buf.pack(0); 190 info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.stopSig); 191 } 192 } 193 catch(jpvmException ex) { 194 ex.printStackTrace(); 195 } 196 } 197 198 public class ToDo { 199 Timer timer; 200 201 public ToDo ( int seconds ) { 202 timer = new Timer ( ) ; 203 timer.schedule ( new ToDoTask ( ) , seconds*1000) ; 204 } 205 206 207 class ToDoTask extends TimerTask { 208 public void run ( ) { 209 if(trialId == netTids.length-1) //no new trial is finished 210 { 211 jpvmBuffer buf = new jpvmBuffer(); 212 try { 213 info.jpvm.pvm_send(buf,info.tids[0],NetMessageTag.checkTime); //check first trial host time 214 } 215 catch(jpvmException ex) { 216 ex.printStackTrace(); 217 System.exit(-1); 218 } 219 } 220 timer.cancel() ; 221 } 222 } 223 224 public void stop() 225 { 226 timer.cancel() ; 227 } 228 } 229 230 /** 231 * @see cnslab.cnsnetwork.FitFun#fitFunction(double[]) fitFunction 232 */ 233 public double fitFunction(double[] para) { 234 double tmpValue=0.0; 235 try { 236 //initialized parameters 237 intraReceiver = new LinkedList[exp.recorder.intraEle.size()*exp.subExp.length]; 238 rdata = new RecorderData(); 239 240 241 //first change the connections for the netHost 242 Document newDoc = parDoc.getDocument(para); 243 NodeList conns = pas.rootElement.getElementsByTagName("Connections"); 244 245 System.out.println("connections num"+conns.getLength()); 246 247 pas.rootElement.removeChild(conns.item(0)); 248 Node dup = pas.document.importNode(newDoc.getDocumentElement() , true); 249 pas.rootElement.appendChild(dup); 250 251 pas.document.normalizeDocument(); //expand everything like save and load 252 253 pas.document.removeChild(pas.documentType); 254 DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance(); 255 DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS"); 256 LSSerializer writer = impl.createLSSerializer(); 257 LSOutput output = impl.createLSOutput(); 258 ByteArrayOutputStream bArray = new ByteArrayOutputStream(); 259 output.setByteStream(bArray); 260 writer.write(pas.document, output); 261 pas.document.appendChild(pas.documentType); 262 byte [] ba = bArray.toByteArray(); 263 264 265 266 seedInt = saveInt; 267 for (int i=0;i<info.numTasks; i++) { 268 // Receive a message... 269 jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.trialReady); 270 int freeId = message.buffer.upkint(); 271 System.out.println("Trial Host "+freeId+" change connection"); 272 seedInt = seedInt - info.endIndex.length; 273 274 jpvmBuffer buf = new jpvmBuffer(); 275 buf.pack(seedInt); 276 buf.pack(ba.length); 277 buf.pack(ba, ba.length, 1); 278 info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.changeConnection); 279 // Unpack the message... 280 } 281 282 trialId = -1; 283 expId = 0; 284 int aliveHost = info.numTasks; 285 int totalTrials=0; 286 for(int j=0;j< exp.subExp.length; j++) 287 { 288 totalTrials += exp.subExp[j].repetition; 289 } 290 291 totalTrials = totalTrials * pas.numOfHosts; 292 boolean stop = false; 293 294 boolean badweight=false; 295 296 //Barrier Sync 297 for (int i=0;i<info.numTasks*info.endIndex.length; i++) { 298 // Receive a message... 299 jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig); 300 // Unpack the message... 301 String str = message.buffer.upkstr(); 302 int posi; 303 if((posi=str.indexOf("$")) >=0 ) 304 { 305 badweight = true; 306 double tmpval = Double.parseDouble(str.substring(posi+1)); 307 if(tmpval > tmpValue) { 308 tmpValue = tmpval; 309 valueSd =0.0; 310 } 311 } 312 System.out.println(str); 313 } 314 315 jpvmBuffer buf ; 316 double per=0.0; 317 boolean overflow=false; 318 319 320 if(!badweight) 321 { 322 System.out.println("************ simulation is starting *********************"); 323 //listening and processing 324 ToDo toDo = new ToDo(90); //1min's threshold 325 int countHosts=0; 326 while(!stop) 327 { 328 jpvmMessage m = info.jpvm.pvm_recv(); //receive info from others 329 switch(m.messageTag) 330 { 331 case NetMessageTag.checkTime: //received percentage time and reset netHosts 332 int hostId = m.buffer.upkint(); 333 int eId = m.buffer.upkint(); 334 double root_time = m.buffer.upkdouble(); 335 per = root_time/exp.subExp[eId].trialLength*100; 336 stop = true; //can stop now; 337 overflow=true; 338 for (int i=0;i<info.numTasks; i++) //reset all the netHosts 339 { 340 buf = new jpvmBuffer(); 341 info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.resetNetHost); 342 } 343 System.out.println("Overflow detected"); 344 break; 345 case NetMessageTag.trialReady: 346 int freeId = m.buffer.upkint(); //get free host id; 347 trialId++; 348// System.out.println("Now trialId is "+ trialId); 349 if(expId < exp.subExp.length) 350 { 351 if(trialId == exp.subExp[expId].repetition && expId+1 == exp.subExp.length) 352 { 353 expId++; 354 } 355 else if(trialId == exp.subExp[expId].repetition && expId+1 != exp.subExp.length) 356 { 357 trialId =0; 358 expId++; 359 } 360 } 361 if(expId < exp.subExp.length) // game is still on; 362 { 363 System.out.println("Subexp "+expId+" trial "+ trialId+" freeId "+freeId); 364 buf = new jpvmBuffer(); 365 buf.pack(trialId); 366 buf.pack(expId); 367 info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.oneTrial); 368 } 369 else if( aliveHost != 0 ) // if all the work are done and some hosts are not killed 370 { 371 System.out.println("host "+freeId+" finished his job"); 372 // buf = new jpvmBuffer(); 373 // buf.pack(0); 374 // info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.stopSig); 375 //aliveHost--; 376 buf = new jpvmBuffer(); 377 buf.pack(0); 378 info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.tempStopSig); 379 } 380 break; 381 case NetMessageTag.getBackData: 382 countHosts++; 383 if(countHosts== info.endIndex.length) {aliveHost--;countHosts=0;} 384 //System.out.println("Receiving data from "); 385 RecorderData spikes = (RecorderData)m.buffer.upkcnsobj(); 386// System.out.print("Receiving data from "+m.buffer.upkint()+" and alive hosts"+aliveHost+"\n"); 387 //comibne for single unit 388 Set entries = spikes.receiver.entrySet(); 389 Iterator entryIter = entries.iterator(); 390 while (entryIter.hasNext()) { 391 Map.Entry<String, LinkedList<Double> > entry = (Map.Entry<String, LinkedList<Double> >)entryIter.next(); 392 String key = entry.getKey(); // Get the key from the entry. 393 LinkedList<Double> value = entry.getValue(); // Get the value. 394 LinkedList<Double> tmp = rdata.receiver.get(key); 395 if(tmp == null) 396 { 397 rdata.receiver.put(key, tmp=(new LinkedList<Double>())); 398 } 399 //if( spike.time !=0.0 ) tmp.add(spike.time); //put received info into memory 400 tmp.addAll(value); //put received info into memory 401 } 402 ///comibne for multi unit 403 entries = spikes.multiCounter.entrySet(); 404 entryIter = entries.iterator(); 405 while (entryIter.hasNext()) { 406 Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)entryIter.next(); 407 String key = entry.getKey(); // Get the key from the entry. 408 Integer value = entry.getValue(); // Get the value. 409 Integer tmp = rdata.multiCounter.get(key); 410 if(tmp == null) 411 { 412 rdata.multiCounter.put(key, tmp=(new Integer(0))); 413 } 414 tmp+=value; 415 rdata.multiCounter.put(key, tmp); 416 } 417 entries = spikes.multiCounterAll.entrySet(); 418 entryIter = entries.iterator(); 419 while (entryIter.hasNext()) { 420 Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)entryIter.next(); 421 String key = entry.getKey(); // Get the key from the entry. 422 Integer value = entry.getValue(); // Get the value. 423 Integer tmp = rdata.multiCounterAll.get(key); 424 if(tmp == null) 425 { 426 rdata.multiCounterAll.put(key, tmp=(new Integer(0))); 427 } 428 tmp+=value; 429 rdata.multiCounterAll.put(key, tmp); 430 } 431 //combine for field ele 432 entries = spikes.fieldCounter.entrySet(); 433 entryIter = entries.iterator(); 434 while (entryIter.hasNext()) { 435 Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)entryIter.next(); 436 String key = entry.getKey(); // Get the key from the entry. 437 // System.out.println(key); 438 Integer value = entry.getValue(); // Get the value. 439 Integer tmp = rdata.fieldCounter.get(key); 440 if(tmp == null) 441 { 442 rdata.fieldCounter.put(key, tmp=(new Integer(0))); 443 } 444 tmp+=value; 445 rdata.fieldCounter.put(key, tmp); 446 } 447 //combine for vector ele 448 entries = spikes.vectorCounterX.entrySet(); 449 entryIter = entries.iterator(); 450 while (entryIter.hasNext()) { 451 Map.Entry<String, Double> entry = (Map.Entry<String, Double>)entryIter.next(); 452 String key = entry.getKey(); // Get the key from the entry. 453 Double value = entry.getValue(); // Get the value. 454 Double tmp = rdata.vectorCounterX.get(key); 455 if(tmp == null) 456 { 457 rdata.vectorCounterX.put(key, tmp=(new Double(0.0))); 458 } 459 tmp+=value; 460 rdata.vectorCounterX.put(key, tmp); 461 } 462 entries = spikes.vectorCounterY.entrySet(); 463 entryIter = entries.iterator(); 464 while (entryIter.hasNext()) { 465 Map.Entry<String, Double> entry = (Map.Entry<String, Double>)entryIter.next(); 466 String key = entry.getKey(); // Get the key from the entry. 467 Double value = entry.getValue(); // Get the value. 468 Double tmp = rdata.vectorCounterY.get(key); 469 if(tmp == null) 470 { 471 rdata.vectorCounterY.put(key, tmp=(new Double(0.0))); 472 } 473 tmp+=value; 474 rdata.vectorCounterY.put(key, tmp); 475 } 476 break; 477 case NetMessageTag.trialDone: 478 totalTrials--; 479 int res_trial = m.buffer.upkint(); 480 int res_exp = m.buffer.upkint(); 481 // System.out.println("R: "+"E"+res_exp+"T"+res_trial); 482 //RecordBuffer spikes = (RecordBuffer)m.buffer.upkcnsobj(); 483 IntraRecBuffer intra = (IntraRecBuffer)m.buffer.upkcnsobj(); 484 /* 485 // System.out.println(intra); 486 487 Iterator<NetRecordSpike> iter_spike = spikes.buff.iterator(); 488 // System.out.println("spikes"+spikes.buff.size()); 489 while(iter_spike.hasNext()) 490 { 491 NetRecordSpike spike = iter_spike.next(); 492 493 LinkedList<Double> tmp = receiver.get("E"+res_exp+"T"+res_trial+"N"+spike.from); 494 495 if(tmp == null) 496 { 497 receiver.put("E"+res_exp+"T"+res_trial+"N"+spike.from, tmp=(new LinkedList<Double>())); 498 } 499 //if( spike.time !=0.0 ) tmp.add(spike.time); //put received info into memory 500 tmp.add(spike.time); //put received info into memory 501 //System.out.println("fire: time:"+spike.time+ " index:"+spike.from); 502 } 503 */ 504 505 for(int i = 0; i < intra.neurons.length; i++) 506 { 507 int neu = intra.neurons[i]; 508 int eleId = exp.recorder.intraIndex(neu); 509 LinkedList<IntraInfo> info = intra.buff.get(i); 510 LinkedList<IntraInfo> currList; 511 if((currList=(LinkedList<IntraInfo>)intraReceiver[eleId+res_exp*exp.recorder.intraEle.size()])!=null) 512 { 513 Iterator<IntraInfo> intraData = info.iterator(); 514 Iterator<IntraInfo> thisData = currList.iterator(); 515 while(intraData.hasNext()) 516 { 517 (thisData.next()).plus(intraData.next()); 518 } 519 } 520 else 521 { 522 intraReceiver[eleId+res_exp*exp.recorder.intraEle.size()]=info; 523 } 524 } 525 break; 526 } 527 528 if( aliveHost == 0 && totalTrials==0) 529 { 530 stop = true; 531 break; 532 } 533 } 534 toDo.stop(); 535 536 if(overflow) 537 { 538 tmpValue = (100.0-per)/100.0; 539 valueSd =0.0; 540 //Barrier Sync 541 } 542 else 543 { 544 tmpValue = parDoc.getFitValue(pas,intraReceiver,rdata); 545 double [] bootVal = new double [200]; //200 times to get Sd. 546 for(int i=0; i < 200; i++) 547 { 548 bootVal[i] = parDoc.getRanFitValue(pas,intraReceiver,rdata,idum); 549 } 550 valueSd = FunUtil.sd(bootVal); 551 for (int i=0;i<info.numTasks; i++) { 552 buf = new jpvmBuffer(); 553 info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.netHostNotify); 554 } 555 } 556 557 } 558 else 559 { 560 for (int i=0;i<info.numTasks; i++) { 561 buf = new jpvmBuffer(); 562 info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.netHostNotify); 563 } 564 } 565 566 for (int i=0;i<info.numTasks*info.endIndex.length; i++) { 567 // Receive a message... 568 jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig); 569 // Unpack the message... 570 String str = message.buffer.upkstr(); 571 System.out.println(str); 572 } 573 574 //start working for the next fitting evaluation 575 576 System.out.println("cost:"+tmpValue+" Sd:"+valueSd); 577 } 578 catch(Exception ex) { 579 ex.printStackTrace(); 580 System.exit(-1); 581 }catch(jpvmException ap) 582 { 583 ap.printStackTrace(); 584 System.exit(-1); 585 } 586 return tmpValue; 587 } 588}