001    package cnslab.cnsnetwork;
002    
003    import org.w3c.dom.Document;
004    import org.w3c.dom.Element;
005    import org.w3c.dom.NodeList;
006    import cnslab.cnsmath.Seed;
007    import org.w3c.dom.Node;
008    import org.w3c.dom.bootstrap.DOMImplementationRegistry;
009    import org.w3c.dom.ls.DOMImplementationLS;
010    import org.w3c.dom.ls.LSSerializer;
011    import org.w3c.dom.ls.LSOutput;
012    import java.io.FileOutputStream;
013    import java.io.File;
014    import java.io.ByteArrayOutputStream;
015    import java.io.ByteArrayInputStream;
016    import java.io.InputStream;
017    import org.w3c.dom.DocumentType;
018    import jpvm.jpvmTaskId;
019    import cnslab.cnsnetwork.JpvmInfo;
020    import java.util.LinkedList;
021    import java.util.HashMap;
022    import java.util.Set;
023    import java.util.Map;
024    import jpvm.jpvmEnvironment;
025    import jpvm.jpvmBuffer;
026    import jpvm.jpvmMessage;
027    import jpvm.jpvmException;
028    import java.util.Iterator;
029    import java.util.Timer;
030    import java.util.TimerTask;
031
032    /***********************************************************************
033    * Deprecated.
034    *  
035    * @version
036    *   $Date: 2012-08-04 13:43:22 -0500 (Sat, 04 Aug 2012) $
037    *   $Rev: 104 $
038    *   $Author: croft $
039    * @author
040    *   Yi Dong
041    * @author
042    *   David Wallace Croft
043    ***********************************************************************/
044@Deprecated
045public class FitFun
046{
047        public FitFun(ParToDoc parDoc, Seed idum, String modelFilename, int heapSize )
048        {
049                try {
050                        this.parDoc = parDoc;
051                        this.seedInt = idum.seed;
052                        this.saveInt = idum.seed;
053                        this.idum = idum;
054
055                        //spawn all the slaves and ready for the simulation             
056                        pas = new SimulatorParser (
057                          idum, new File ( "model/" + modelFilename ) );
058                        
059                        pas.parseMapCells();
060                        pas.parseExperiment();
061                        exp = pas.experiment; // experiment infomation
062                        pas.findMinDelay();
063
064                        pas.document.removeChild(pas.documentType);
065                        DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();
066                        DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS");
067                        LSSerializer writer = impl.createLSSerializer();
068                        LSOutput output = impl.createLSOutput();
069                        ByteArrayOutputStream bArray = new ByteArrayOutputStream();
070                        output.setByteStream(bArray);
071                        writer.write(pas.document, output);
072                        pas.document.appendChild(pas.documentType);
073                        byte [] ba = bArray.toByteArray();
074
075
076                        info= new JpvmInfo();
077                        // Enroll in the parallel virtual machine...
078                        info.jpvm = new jpvmEnvironment();
079
080                        // Get my task id...
081                        info.myJpvmTaskId = info.jpvm.pvm_mytid();
082                        System.out.println("Task Id: "+info.myJpvmTaskId.toString());
083
084                        info.numTasks= pas.parallelHost; // total number of trial hosts; 
085                        info.idIndex=info.numTasks; // root id; //not used at all
086                        info.tids = new jpvmTaskId[info.numTasks];
087
088                        // Spawn some  trialHosts
089                        info.jpvm.pvm_spawn("cnslab.cnsnetwork.TrialHost",info.numTasks,info.tids,48);
090                        System.out.println("spawn successfully");
091
092                        jpvmBuffer buf2 = new jpvmBuffer();
093
094                        buf2.pack(info.numTasks);
095                        buf2.pack(info.tids, info.numTasks, 1);
096                        buf2.pack(pas.minDelay);
097
098                        info.jpvm.pvm_mcast(buf2,info.tids,info.numTasks,NetMessageTag.sendTids);
099
100                        System.out.println("All sent");
101
102                        //gernerate all the nethosts :
103                        info.endIndex=pas.layerStructure.nodeEndIndices;
104
105                        netTids =  new jpvmTaskId [info.numTasks][info.endIndex.length] ;
106
107                        for(int i = 0 ; i < info.numTasks; i++)
108                        {
109                                System.out.println("generate child for trialHost "+i);
110                                info.jpvm.pvm_spawn("cnslab.cnsnetwork.NetHostTune",info.endIndex.length,netTids[i],heapSize); //Net Host is to seperate large network into small pieces;
111                                jpvmBuffer buf = new jpvmBuffer();
112                                buf.pack(info.endIndex.length);
113                                buf.pack(netTids[i],info.endIndex.length,1);
114                                buf.pack(info.endIndex,info.endIndex.length,1);
115                                seedInt = seedInt - info.endIndex.length;
116                                buf.pack(seedInt);
117                                buf.pack(info.tids[i]); //parent's tid;
118                                buf.pack(ba.length);
119                                buf.pack(ba, ba.length, 1);
120                                info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.sendTids);
121                        }
122
123                        for(int i = 0 ; i < info.numTasks; i++)
124                        {
125                                jpvmBuffer buf = new jpvmBuffer();
126                                buf.pack(info.endIndex.length);
127                                buf.pack(netTids[i],info.endIndex.length,1);
128                                info.jpvm.pvm_send(buf,info.tids[i],NetMessageTag.sendTids2); //send trial Host the child tids
129                        }
130
131
132                        //Barrier Sync
133                        for (int i=0;i<info.numTasks*info.endIndex.length+info.numTasks; i++) {
134                                // Receive a message...
135                                jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig);
136                                // Unpack the message...
137                                String str = message.buffer.upkstr();
138                                System.out.println(str);
139                        }
140                        jpvmBuffer buf = new jpvmBuffer();
141
142                }
143                catch(jpvmException ex) {
144                        ex.printStackTrace();
145                }catch(Exception ap)
146                {
147                        ap.printStackTrace();
148                }
149        }
150
151        public Seed idum; // seed object handed to the constructor
152
153        public double valueSd; // value's standard deviation; fitFunction zeroes it when a ready message reports a larger '$'-tagged value
154
155        public  int trialId = -1; // trial counter: fitFunction resets it to -1 and increments it on every trialReady message
156
157        public  int expId = 0; // index of the current sub-experiment in exp.subExp; fitFunction resets it to 0
158
159        public int saveInt; // copy of the initial seed value; fitFunction restores seedInt from it
160
161
162        public ParToDoc parDoc; // supplies the Document for a parameter vector (parDoc.getDocument(para))
163
164        public Experiment exp; // taken from pas.experiment in the constructor
165
166        public JpvmInfo info; // JPVM environment, own task id, trial-host tids and net-host end indices
167        public int [] endIndex; // NOTE(review): never assigned in the code visible here (info.endIndex is used instead) -- confirm whether still needed
168
169        public double minDelay; // NOTE(review): never assigned in the code visible here (pas.minDelay is sent instead) -- confirm
170        public double backFire; // NOTE(review): not referenced in the code visible here -- confirm
171        public int seedInt; // running seed; reduced by info.endIndex.length before it is sent to each net-host group
172
173        public jpvmTaskId[][] netTids; // net-host task ids, indexed [trial host][net host]
174
175        public LinkedList[] intraReceiver ; // one slot per (sub-experiment, intra element); re-allocated at the start of fitFunction
176        public RecorderData rdata; // recorder data merged from getBackData messages; re-created at the start of fitFunction
177
178
179        public SimulatorParser pas; // parser for the model file; owns the XML document that is serialized for the hosts
180
181        public void closeHosts() {
182                try {
183
184                        for (int i=0;i<info.numTasks; i++) {
185                                jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.trialReady);
186                                int freeId = message.buffer.upkint(); 
187                                System.out.println("Trial Host "+freeId+" is killed");
188                                jpvmBuffer buf = new jpvmBuffer();
189                                buf.pack(0);
190                                info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.stopSig);
191                        }
192                }
193                catch(jpvmException ex) {
194                        ex.printStackTrace();
195                }
196        }
197
198        public class ToDo  {
199                Timer timer;
200
201                public ToDo ( int seconds )   {
202                        timer = new Timer (  ) ;
203                        timer.schedule ( new ToDoTask (  ) , seconds*1000) ;
204                }
205
206
207                class ToDoTask extends TimerTask  {
208                        public void run (  )   {
209                                if(trialId == netTids.length-1) //no new trial is finished
210                                {
211                                        jpvmBuffer buf = new jpvmBuffer();
212                                        try {
213                                                info.jpvm.pvm_send(buf,info.tids[0],NetMessageTag.checkTime); //check first trial host time
214                                        }
215                                        catch(jpvmException ex) {
216                                                ex.printStackTrace();
217                                                System.exit(-1);
218                                        }
219                                }
220                                timer.cancel() ;
221                        }
222                }
223
224                public void stop()
225                {
226                        timer.cancel() ;
227                }
228        }
229
230        /**
231         * @see cnslab.cnsnetwork.FitFun#fitFunction(double[]) fitFunction
232         */
233        public double fitFunction(double[] para) {
234                double tmpValue=0.0;
235                try {
236                        //initialized parameters
237                        intraReceiver = new LinkedList[exp.recorder.intraEle.size()*exp.subExp.length];
238                        rdata = new RecorderData();
239
240
241                        //first change the connections for the netHost
242                        Document newDoc = parDoc.getDocument(para);
243                        NodeList conns = pas.rootElement.getElementsByTagName("Connections");
244
245                        System.out.println("connections num"+conns.getLength());
246
247                        pas.rootElement.removeChild(conns.item(0));
248                        Node dup = pas.document.importNode(newDoc.getDocumentElement() , true);
249                        pas.rootElement.appendChild(dup);
250
251                        pas.document.normalizeDocument(); //expand everything like save and load
252
253                        pas.document.removeChild(pas.documentType);
254                        DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();
255                        DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS");
256                        LSSerializer writer = impl.createLSSerializer();
257                        LSOutput output = impl.createLSOutput();
258                        ByteArrayOutputStream bArray = new ByteArrayOutputStream();
259                        output.setByteStream(bArray);
260                        writer.write(pas.document, output);
261                        pas.document.appendChild(pas.documentType);
262                        byte [] ba = bArray.toByteArray();
263
264
265
266                        seedInt = saveInt;
267                        for (int i=0;i<info.numTasks; i++) {
268                                // Receive a message...
269                                jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.trialReady);
270                                int freeId = message.buffer.upkint(); 
271                                System.out.println("Trial Host "+freeId+" change connection");
272                                seedInt = seedInt - info.endIndex.length;
273
274                                jpvmBuffer buf = new jpvmBuffer();
275                                buf.pack(seedInt);
276                                buf.pack(ba.length);
277                                buf.pack(ba, ba.length, 1);
278                                info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.changeConnection);
279                                // Unpack the message...
280                        }
281
282                        trialId = -1;
283                        expId = 0;
284                        int aliveHost = info.numTasks;
285                        int totalTrials=0;
286                        for(int j=0;j< exp.subExp.length; j++)
287                        {
288                                totalTrials += exp.subExp[j].repetition;
289                        }
290
291                        totalTrials = totalTrials * pas.numOfHosts; 
292                        boolean stop = false;
293
294                        boolean badweight=false;
295
296                        //Barrier Sync
297                        for (int i=0;i<info.numTasks*info.endIndex.length; i++) {
298                                // Receive a message...
299                                jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig);
300                                // Unpack the message...
301                                String str = message.buffer.upkstr();
302                                int posi;
303                                if((posi=str.indexOf("$")) >=0 )
304                                {
305                                        badweight = true;
306                                        double tmpval =  Double.parseDouble(str.substring(posi+1));
307                                        if(tmpval > tmpValue) {
308                                                tmpValue = tmpval;
309                                                valueSd =0.0;
310                                        }
311                                }
312                                System.out.println(str);
313                        }
314
315                        jpvmBuffer buf ;
316                        double per=0.0;
317                        boolean overflow=false;
318
319
320                        if(!badweight)
321                        {
322                                System.out.println("************ simulation is starting *********************");
323                                //listening and processing
324                                ToDo toDo = new ToDo(90); //1min's threshold
325                                int countHosts=0;
326                                while(!stop)
327                                {
328                                        jpvmMessage m = info.jpvm.pvm_recv(); //receive info from others
329                                        switch(m.messageTag)
330                                        {
331                                                case NetMessageTag.checkTime: //received percentage time and reset netHosts
332                                                        int hostId = m.buffer.upkint();
333                                                        int eId = m.buffer.upkint();
334                                                        double root_time = m.buffer.upkdouble();
335                                                        per = root_time/exp.subExp[eId].trialLength*100;
336                                                        stop = true; //can stop now;
337                                                        overflow=true;
338                                                        for (int i=0;i<info.numTasks; i++) //reset all the netHosts
339                                                        {
340                                                                buf = new jpvmBuffer();
341                                                                info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.resetNetHost);
342                                                        }
343                                                        System.out.println("Overflow detected");
344                                                        break;
345                                                case NetMessageTag.trialReady:
346                                                        int freeId = m.buffer.upkint(); //get free host id;
347                                                        trialId++;
348//                                                      System.out.println("Now trialId is "+ trialId);
349                                                        if(expId < exp.subExp.length)
350                                                        {
351                                                                if(trialId == exp.subExp[expId].repetition && expId+1 == exp.subExp.length)
352                                                                {       
353                                                                        expId++;
354                                                                }
355                                                                else if(trialId == exp.subExp[expId].repetition && expId+1 != exp.subExp.length)                                                           
356                                                                {       
357                                                                        trialId =0;
358                                                                        expId++;
359                                                                }
360                                                        }
361                                                        if(expId < exp.subExp.length)  // game is still on;
362                                                        {
363                                                                System.out.println("Subexp "+expId+" trial "+ trialId+" freeId "+freeId);
364                                                                buf = new jpvmBuffer();
365                                                                buf.pack(trialId);
366                                                                buf.pack(expId);
367                                                                info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.oneTrial);
368                                                        }
369                                                        else if( aliveHost != 0 ) // if all the work are done and some hosts are not killed
370                                                        {
371                                                                System.out.println("host "+freeId+" finished his job");
372                                                                //                                                      buf = new jpvmBuffer();
373                                                                //                                                      buf.pack(0);
374                                                                //                                                      info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.stopSig);
375                                                                //aliveHost--;
376                                                                buf = new jpvmBuffer();
377                                                                buf.pack(0);
378                                                                info.jpvm.pvm_send(buf,info.tids[freeId],NetMessageTag.tempStopSig);                                          
379                                                        }
380                                                        break;
381                                                case NetMessageTag.getBackData:
382                                                        countHosts++;
383                                                        if(countHosts== info.endIndex.length) {aliveHost--;countHosts=0;}
384                                                        //System.out.println("Receiving data from ");
385                                                        RecorderData spikes = (RecorderData)m.buffer.upkcnsobj();
386//                                                      System.out.print("Receiving data from "+m.buffer.upkint()+" and alive hosts"+aliveHost+"\n");
387                                                        //comibne for single unit
388                                                        Set entries = spikes.receiver.entrySet();
389                                                        Iterator entryIter = entries.iterator();
390                                                        while (entryIter.hasNext()) {
391                                                                Map.Entry<String, LinkedList<Double> > entry = (Map.Entry<String, LinkedList<Double> >)entryIter.next();
392                                                                String key = entry.getKey();  // Get the key from the entry.
393                                                                LinkedList<Double> value = entry.getValue();  // Get the value.
394                                                                LinkedList<Double> tmp = rdata.receiver.get(key);
395                                                                if(tmp == null)
396                                                                {
397                                                                        rdata.receiver.put(key, tmp=(new LinkedList<Double>()));
398                                                                }
399                                                                //if( spike.time !=0.0 ) tmp.add(spike.time); //put received info into memory
400                                                                tmp.addAll(value); //put received info into memory
401                                                        }
402                                                        ///comibne for multi unit
403                                                        entries = spikes.multiCounter.entrySet();
404                                                        entryIter = entries.iterator();
405                                                        while (entryIter.hasNext()) {
406                                                                Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)entryIter.next();
407                                                                String key = entry.getKey();  // Get the key from the entry.
408                                                                Integer value = entry.getValue();  // Get the value.
409                                                                Integer tmp = rdata.multiCounter.get(key);
410                                                                if(tmp == null)
411                                                                {
412                                                                        rdata.multiCounter.put(key, tmp=(new Integer(0)));
413                                                                }
414                                                                tmp+=value;
415                                                                rdata.multiCounter.put(key, tmp);
416                                                        }
417                                                        entries = spikes.multiCounterAll.entrySet();
418                                                        entryIter = entries.iterator();
419                                                        while (entryIter.hasNext()) {
420                                                                Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)entryIter.next();
421                                                                String key = entry.getKey();  // Get the key from the entry.
422                                                                Integer value = entry.getValue();  // Get the value.
423                                                                Integer tmp = rdata.multiCounterAll.get(key);
424                                                                if(tmp == null)
425                                                                {
426                                                                        rdata.multiCounterAll.put(key, tmp=(new Integer(0)));
427                                                                }
428                                                                tmp+=value;
429                                                                rdata.multiCounterAll.put(key, tmp);
430                                                        }
431                                                        //combine for field ele
432                                                        entries = spikes.fieldCounter.entrySet();
433                                                        entryIter = entries.iterator();
434                                                        while (entryIter.hasNext()) {
435                                                                Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)entryIter.next();
436                                                                String key = entry.getKey();  // Get the key from the entry.
437                                                                //                                                      System.out.println(key);
438                                                                Integer value = entry.getValue();  // Get the value.
439                                                                Integer tmp = rdata.fieldCounter.get(key);
440                                                                if(tmp == null)
441                                                                {
442                                                                        rdata.fieldCounter.put(key, tmp=(new Integer(0)));
443                                                                }
444                                                                tmp+=value;
445                                                                rdata.fieldCounter.put(key, tmp);
446                                                        }
447                                                        //combine for vector ele
448                                                        entries = spikes.vectorCounterX.entrySet();
449                                                        entryIter = entries.iterator();
450                                                        while (entryIter.hasNext()) {
451                                                                Map.Entry<String, Double> entry = (Map.Entry<String, Double>)entryIter.next();
452                                                                String key = entry.getKey();  // Get the key from the entry.
453                                                                Double value = entry.getValue();  // Get the value.
454                                                                Double tmp = rdata.vectorCounterX.get(key);
455                                                                if(tmp == null)
456                                                                {
457                                                                        rdata.vectorCounterX.put(key, tmp=(new Double(0.0)));
458                                                                }
459                                                                tmp+=value;
460                                                                rdata.vectorCounterX.put(key, tmp);
461                                                        }
462                                                        entries = spikes.vectorCounterY.entrySet();
463                                                        entryIter = entries.iterator();
464                                                        while (entryIter.hasNext()) {
465                                                                Map.Entry<String, Double> entry = (Map.Entry<String, Double>)entryIter.next();
466                                                                String key = entry.getKey();  // Get the key from the entry.
467                                                                Double value = entry.getValue();  // Get the value.
468                                                                Double tmp = rdata.vectorCounterY.get(key);
469                                                                if(tmp == null)
470                                                                {
471                                                                        rdata.vectorCounterY.put(key, tmp=(new Double(0.0)));
472                                                                }
473                                                                tmp+=value;
474                                                                rdata.vectorCounterY.put(key, tmp);
475                                                        }
476                                                        break;
477                                                case NetMessageTag.trialDone:
478                                                        totalTrials--;
479                                                        int res_trial = m.buffer.upkint();
480                                                        int res_exp = m.buffer.upkint();
481                                                        //      System.out.println("R: "+"E"+res_exp+"T"+res_trial);
482                                                        //RecordBuffer spikes = (RecordBuffer)m.buffer.upkcnsobj();
483                                                        IntraRecBuffer intra = (IntraRecBuffer)m.buffer.upkcnsobj();
484                                                        /*
485                                                        //                      System.out.println(intra);
486
487                                                        Iterator<NetRecordSpike> iter_spike = spikes.buff.iterator();
488                                                        //                                      System.out.println("spikes"+spikes.buff.size());
489                                                        while(iter_spike.hasNext())
490                                                        {
491                                                        NetRecordSpike spike = iter_spike.next();
492
493                                                        LinkedList<Double> tmp = receiver.get("E"+res_exp+"T"+res_trial+"N"+spike.from);
494
495                                                        if(tmp == null)
496                                                        {
497                                                        receiver.put("E"+res_exp+"T"+res_trial+"N"+spike.from, tmp=(new LinkedList<Double>()));
498                                                        }
499                                                        //if( spike.time !=0.0 ) tmp.add(spike.time); //put received info into memory
500                                                        tmp.add(spike.time); //put received info into memory
501                                                        //System.out.println("fire: time:"+spike.time+ " index:"+spike.from);
502                                                        }
503                                                        */
504
505                                                        for(int i = 0; i < intra.neurons.length; i++)
506                                                        {
507                                                                int neu = intra.neurons[i];
508                                                                int eleId = exp.recorder.intraIndex(neu);
509                                                                LinkedList<IntraInfo> info = intra.buff.get(i);
510                                                                LinkedList<IntraInfo> currList;
511                                                                if((currList=(LinkedList<IntraInfo>)intraReceiver[eleId+res_exp*exp.recorder.intraEle.size()])!=null)
512                                                                {
513                                                                        Iterator<IntraInfo> intraData = info.iterator();
514                                                                        Iterator<IntraInfo> thisData = currList.iterator();
515                                                                        while(intraData.hasNext())
516                                                                        {
517                                                                                (thisData.next()).plus(intraData.next());
518                                                                        }
519                                                                }
520                                                                else
521                                                                {
522                                                                        intraReceiver[eleId+res_exp*exp.recorder.intraEle.size()]=info; 
523                                                                }
524                                                        }
525                                                        break;
526                                        }
527
528                                        if( aliveHost == 0 && totalTrials==0) 
529                                        {
530                                                stop = true;
531                                                break;
532                                        }
533                                }
534                                toDo.stop();
535
536                                if(overflow)
537                                {
538                                        tmpValue = (100.0-per)/100.0;
539                                        valueSd =0.0;
540                                        //Barrier Sync
541                                }
542                                else
543                                {
544                                        tmpValue = parDoc.getFitValue(pas,intraReceiver,rdata);
545                                        double [] bootVal = new double [200]; //200 times to get Sd.
546                                        for(int i=0; i < 200; i++)
547                                        {
548                                                bootVal[i] = parDoc.getRanFitValue(pas,intraReceiver,rdata,idum);
549                                        }
550                                        valueSd = FunUtil.sd(bootVal);
551                                        for (int i=0;i<info.numTasks; i++) {
552                                                buf = new jpvmBuffer();
553                                                info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.netHostNotify);
554                                        }
555                                }
556
557                        }
558                        else
559                        {
560                                for (int i=0;i<info.numTasks; i++) {
561                                        buf = new jpvmBuffer();
562                                        info.jpvm.pvm_mcast(buf,netTids[i],info.endIndex.length,NetMessageTag.netHostNotify);
563                                }
564                        }
565
566                        for (int i=0;i<info.numTasks*info.endIndex.length; i++) {
567                                // Receive a message...
568                                jpvmMessage message = info.jpvm.pvm_recv(NetMessageTag.readySig);
569                                // Unpack the message...
570                                String str = message.buffer.upkstr();
571                                System.out.println(str);
572                        }
573
574                        //start working for the next fitting evaluation
575
576                        System.out.println("cost:"+tmpValue+" Sd:"+valueSd);
577                }
578                catch(Exception ex) {
579                        ex.printStackTrace();
580                        System.exit(-1);
581                }catch(jpvmException ap)
582                {
583                        ap.printStackTrace();
584                        System.exit(-1);
585                }
586                return tmpValue;
587        }
588}