001 package cnslab.cnsnetwork; 002 003 import java.util.*; 004 005 import cnslab.cnsmath.*; 006 import edu.jhu.mb.ernst.model.Synapse; 007 import edu.jhu.mb.ernst.util.slot.Slot; 008 009 /*********************************************************************** 010 * MN model with STDP implemented. 011 * 012 * See Mihalas and Niebur 2009 paper. 013 * 014 * @version 015 * $Date: 2012-08-04 13:43:22 -0500 (Sat, 04 Aug 2012) $ 016 * $Rev: 104 $ 017 * $Author: croft $ 018 * @author 019 * Yi Dong 020 * @author 021 * David Wallace Croft, M.Sc. 022 * @author 023 * Jeremy Cohen 024 ***********************************************************************/ 025 public final class MiNiNeuron 026 implements Neuron 027 //////////////////////////////////////////////////////////////////////// 028 //////////////////////////////////////////////////////////////////////// 029 { 030 031 private double timeOfNextFire; 032 033 private double timeOfLastUpdate; 034 035 /** whether the neuron is recordable or not */ 036 private boolean record; 037 038 /** table for STDP that stores presynaptic event time and weight */ 039 private Map<Synapse, TimeWeight> histTable; 040 041 /** the epsilon of a final firing time calculation */ 042 public static double tEPS = 1e-12; 043 044 /** the criterion for removing synapses: if a synaptic channel does not 045 get an input for a long time such that its current decays below rEPS, 046 it is removed from computations of the voltage and future updates up 047 until it gets an input */ 048 public static double rEPS = 1e-22; 049 050 /** the cost associated with one neuron update (see D’Haene, 2009) */ 051 public static double cost_Update = 10; 052 053 /** the cost associated with one queue schedule (see D'Haene, 2009) */ 054 public static double cost_Schedule = 1; 055 056 /** the expected cost of inserting a preliminary fire estimate into the 057 * event queue (as opposed to performing another NR iteration) */ 058 public static double cost; 059 060 /** running average of the time interval between input spikes */ 061 public double tAvg; 062 063 /** last time the neuron received a spike */ 064 public double lastInputTime; 065 066 /** last time the neuron fired; negative means no last spike */ 067 public double lastFireTime = -1.0; 068 069 /** clamp the membrane voltage during absolute refractory period */ 070 public double clampVoltage; 071 072 /** clamp the threshold during absolute refractory period */ 073 public double clampThreshold; 074 075 /** whether the next fire event is a final calculation, 076 * as opposed to a preliminary estimate (see D'Haene, 2009) */ 077 public boolean fire; 078 079 /** long imposes a constraint on the maximum number of hosts to be 64 */ 080 public long tHost; 081 082 /** neuron parameters */ 083 public MiNiNeuronPara para; 084 085 /** linked list for state variables */ 086 public TreeSet<State> state; 087 088 /** 089 * two-dimensional array pointing to the state variables. The first 090 * dimension corresponds to type -- NG, NB, NJ_SPIKE, and NJ_SYNAPSE, in 091 * that order. The second dimension corresponds to index. The NG and NB 092 * arrays have only one element. 093 */ 094 public State [ ] [ ] sta_p; 095 096 097 //////////////////////////////////////////////////////////////////////// 098 // inner classes 099 //////////////////////////////////////////////////////////////////////// 100 101 public final class State 102 implements Comparable<State> 103 //////////////////////////////////////////////////////////////////////// 104 //////////////////////////////////////////////////////////////////////// 105 { 106 /** 107 * values of <code>type</code> 108 */ 109 public final static int 110 NG = 0, 111 NB = 1, 112 NJ_SPIKE = 2, 113 NJ_SYNAPSE = 3; 114 115 /** 116 * The time when this State was last updated. 117 */ 118 public double time; 119 120 /** 121 * The value of this State at the time when it was last updated. 122 */ 123 public double value; 124 125 /** 126 * The term in the thresholdDiff function that this State represents 127 * (see Mihalas and Niebur, 2009, Equation 3.5.) 128 */ 129 public int type; 130 131 /** 132 * The index of this State among other States of the same type. 133 */ 134 public int index; 135 136 //////////////////////////////////////////////////////////////////////// 137 //////////////////////////////////////////////////////////////////////// 138 139 public State ( 140 final double time, 141 final double value, 142 final int type, 143 final int index) 144 //////////////////////////////////////////////////////////////////////// 145 { 146 this.time = time; 147 148 this.value = value; 149 150 this.type = type; 151 152 this.index = index; 153 } 154 155 //////////////////////////////////////////////////////////////////////// 156 //////////////////////////////////////////////////////////////////////// 157 158 /** 159 * Sorts in descending order of decay rate. 160 */ 161 @Override 162 public int compareTo ( final State arg0 ) 163 //////////////////////////////////////////////////////////////////////// 164 { 165 if ( para.allDecays [ this.type ] [ this.index ] < 166 para.allDecays [ arg0.type ] [ arg0.index ]) 167 { 168 return 1; 169 } 170 else if ( para.allDecays [ this.type ] [ this.index ] > 171 para.allDecays [ arg0.type ] [ arg0.index ]) 172 { 173 return -1; 174 } 175 176 if ( this.type < arg0.type ) 177 { 178 return -1; 179 } 180 else if ( this.type > arg0.type ) 181 { 182 return 1; 183 } 184 185 if (this.index < arg0.index) 186 { 187 return -1; 188 } 189 else if (this.index > arg0.index) 190 { 191 return 1; 192 } 193 194 else 195 { 196 return 0; 197 } 198 } 199 200 @Override 201 public String toString ( ) 202 //////////////////////////////////////////////////////////////////////// 203 { 204 return 205 "time:" + time 206 + " value:" + value 207 + " type:" + type 208 + " index:" + index 209 + " decay:" + para.allDecays [ type ] [ index ]; 210 } 211 212 //////////////////////////////////////////////////////////////////////// 213 //////////////////////////////////////////////////////////////////////// 214 } 215 216 //////////////////////////////////////////////////////////////////////// 217 // constructor methods 218 //////////////////////////////////////////////////////////////////////// 219 220 public MiNiNeuron ( final MiNiNeuronPara para ) 221 //////////////////////////////////////////////////////////////////////// 222 { 223 this.para = para; 224 } 225 226 //////////////////////////////////////////////////////////////////////// 227 // interface Neuron accessor methods 228 //////////////////////////////////////////////////////////////////////// 229 230 /** 231 * Used mainly for intracellular recording. 232 * 233 * @param currTime the current time. 234 * @return the synaptic currents in this neuron at time currTime. 235 */ 236 @Override 237 public double [ ] getCurr ( final double currTime ) 238 //////////////////////////////////////////////////////////////////////// 239 { 240 final double [ ] out = 241 new double [ sta_p [ State.NJ_SYNAPSE ].length ]; 242 243 for ( int a = 0; a < sta_p [ State.NJ_SYNAPSE ].length; a++ ) 244 { 245 State state = sta_p [ State.NJ_SYNAPSE ] [ a ]; 246 247 if ( state == null ) 248 { 249 out [ a ] = 0; 250 } 251 else 252 { 253 // Convert NJ_SYNAPSE state to synapse current, and update to 254 // currTime. 255 256 out [ a ] = state.value * Math.exp ( 257 -( currTime - state.time ) 258 * para.allDecays [ State.NJ_SYNAPSE ] [ a ] ) 259 / para.SYNAPSE_CURRENT_TO_SYNAPSE_NJ [ a ]; 260 } 261 } 262 263 return out; 264 } 265 266 /** 267 * Used mainly for intracellular recording. 268 * 269 * @param currTime the current time. 270 * @return the membrane voltage in this neuron at time currTime. 271 */ 272 @Override 273 public double getMemV ( final double currTime ) 274 //////////////////////////////////////////////////////////////////////// 275 { 276 if ( currTime < timeOfLastUpdate ) 277 { 278 // If input comes inside the refractory period, 279 // the membrane voltage remains unchanged. 280 281 return membraneVoltage ( timeOfLastUpdate ); 282 } 283 else 284 { 285 return membraneVoltage (currTime); 286 } 287 } 288 289 @Override 290 public boolean getRecord ( ) 291 //////////////////////////////////////////////////////////////////////// 292 { 293 return this.record; 294 } 295 296 @Override 297 public long getTHost ( ) 298 //////////////////////////////////////////////////////////////////////// 299 { 300 return tHost; 301 } 302 303 @Override 304 public double getTimeOfNextFire ( ) 305 //////////////////////////////////////////////////////////////////////// 306 { 307 return this.timeOfNextFire; 308 } 309 310 @Override 311 public boolean isSensory ( ) 312 //////////////////////////////////////////////////////////////////////// 313 { 314 return false; 315 } 316 317 @Override 318 public boolean realFire ( ) 319 //////////////////////////////////////////////////////////////////////// 320 { 321 // Whether the next fire event is a final calculation, 322 // as opposed to a preliminary estimate (see D'Haene, 2009). 323 324 return fire; 325 } 326 327 //////////////////////////////////////////////////////////////////////// 328 // interface Neuron mutator methods 329 //////////////////////////////////////////////////////////////////////// 330 331 @Override 332 public void setRecord ( final boolean record ) 333 //////////////////////////////////////////////////////////////////////// 334 { 335 this.record = record; 336 } 337 338 @Override 339 public void setTHost ( final long id ) 340 //////////////////////////////////////////////////////////////////////// 341 { 342 this.tHost = id; 343 } 344 345 @Override 346 public void setTimeOfNextFire ( final double timeOfNextFire ) 347 //////////////////////////////////////////////////////////////////////// 348 { 349 this.timeOfNextFire = timeOfNextFire; 350 } 351 352 //////////////////////////////////////////////////////////////////////// 353 // interface Neuron lifecycle methods 354 //////////////////////////////////////////////////////////////////////// 355 356 @Override 357 public void init ( 358 final int expid, 359 final int trialid, 360 final Seed idum, 361 final Network net, 362 final int id ) 363 //////////////////////////////////////////////////////////////////////// 364 { 365 // initialization all the initial parameters. 366 367 cost = Math.log ( cost_Schedule / cost_Update + 1.0 ); 368 369 this.tAvg = 0.005; // initial value is 5 ms; 370 371 this.lastInputTime = 0.0; 372 373 this.timeOfLastUpdate = 0.0; 374 375 this.timeOfNextFire = -1; 376 377 double initialThreshold = para.ini_threshold; 378 379 // Pick the initial membrane voltage from a random, uniform spread. 380 381 double initialVoltage = 382 para.ini_mem + para.ini_memVar * Cnsran.ran2 ( idum ); 383 384 double [ ] initialSpikeCurrents = 385 new double [ para.ini_spike_curr.length ]; 386 387 for ( int a = 0; a < initialSpikeCurrents.length; a++ ) 388 initialSpikeCurrents [ a ] = para.ini_spike_curr [ a ]; 389 390 // Set up the history table to moniter of synaptic activity for STDP 391 392 histTable = new HashMap<Synapse,TimeWeight> ( ); 393 394 // Last time the neuron fired. 395 396 this.lastFireTime = -1.0; 397 398 // Set up state variables. 399 400 // Create linked table for states. 401 402 state = new TreeSet<State> ( ); 403 404 sta_p = new State [ 4 ] [ ]; 405 406 // Create the NG term. 407 408 double ngTerm = para.NG_BASE; 409 410 ngTerm += initialVoltage * para.VOLTAGE_TO_NG; 411 412 for (int a = 0; a < initialSpikeCurrents.length; a ++) 413 { 414 ngTerm += 415 initialSpikeCurrents [ a ] * para.SPIKE_CURRENT_TO_NG [ a ]; 416 } 417 418 sta_p [ State.NG ] = new State [ ] { 419 new State ( 0, ngTerm, State.NG, 0 ) 420 }; 421 422 state.add ( sta_p [ State.NG ] [ 0 ] ); 423 424 // Create the NB term. 425 426 double nbTerm = para.NB_BASE; 427 428 nbTerm += initialThreshold * para.THRESHOLD_TO_NB; 429 nbTerm += initialVoltage * para.VOLTAGE_TO_NB; 430 431 sta_p [ State.NB ] = new State [ ] { 432 new State ( 0, nbTerm, State.NB , 0 ) 433 }; 434 435 state.add ( sta_p [ State.NB ] [ 0 ] ); 436 437 // Create the NJ_SPIKE terms. 438 439 sta_p [ State.NJ_SPIKE ] = 440 new State [ para.allDecays [ State.NJ_SPIKE ].length ]; 441 442 for ( int i = 0; i < sta_p [ State.NJ_SPIKE ].length; i++ ) 443 { 444 double njTerm = 445 initialSpikeCurrents [ i ] * para.SPIKE_CURRENT_TO_SPIKE_NJ [ i ]; 446 447 sta_p [ State.NJ_SPIKE ] [ i ] = 448 new State ( 0 , njTerm , State.NJ_SPIKE , i ); 449 450 state.add ( sta_p [ State.NJ_SPIKE ] [ i ] ); 451 } 452 453 // Create the NJ_SYNAPSE terms. 454 455 sta_p [ State.NJ_SYNAPSE ] = 456 new State [ para.allDecays [ State.NJ_SYNAPSE ].length ]; 457 458 for ( int i = 0; i < sta_p [ State.NJ_SYNAPSE ].length; i++ ) 459 { 460 sta_p [ State.NJ_SYNAPSE ] [ i ] = 461 new State ( 0 , 0 , State.NJ_SYNAPSE , i ); 462 463 state.add ( sta_p [ State.NJ_SYNAPSE ] [ i ] ); 464 } 465 466 // Schedule the first fire event, if one should exist. 467 468 // Use a modified Newton-Raphson (NR) root-finding iteration method to 469 // find the first zero of the thresholdDiff function. 470 471 double deltaT, thresholdDiff; 472 473 double deriv = safeDerivative ( 0.0 ); 474 475 // Ever-increasing estimate of the absolute time of the next fire 476 // event. 477 478 double nextTime = ( -( initialVoltage - initialThreshold ) / deriv ); 479 480 deltaT = nextTime; 481 482 while ( deriv > 0 && nextTime < 1.0 483 && !( deltaT > 0 ? deltaT < tEPS : -deltaT < tEPS ) ) 484 { 485 // Safe derivative of the thresholdDiff function at time = nextTime. 486 487 deriv = safeDerivative ( nextTime ); 488 489 // Value of the thresholdDiff function at time = nextTime. 490 491 thresholdDiff = thresholdDiff ( nextTime ); 492 493 // As per the NR algorithm. 494 495 deltaT = ( -thresholdDiff / deriv ); 496 497 nextTime += deltaT; 498 } 499 500 // If the derivative is decreasing or if the algorithm has gone far 501 // enough, conclude that this neuron will not fire. 502 503 if ( deriv < 0 || nextTime > 1.0 ) 504 { 505 fire = false; 506 507 nextTime = -1.0; 508 } 509 510 // If deltaT is within the proper precision, schedule an actual fire 511 // event. 512 513 else if ( ( deltaT > 0 ? deltaT < tEPS : -deltaT < tEPS ) ) 514 { 515 fire = true; 516 } 517 else 518 { 519 // This shouldn't happen. 520 } 521 522 if ( nextTime < 0 ) 523 { 524 // A negative value for nextTime means that this neuron will not fire. 525 } 526 else 527 { 528 final Slot<FireEvent> fireEventSlot = net.getFireEventSlot ( ); 529 530 if ( net.getClass ( ).getName().equals ( 531 "cnslab.cnsnetwork.ANetwork" ) ) 532 { 533 fireEventSlot.offer ( 534 new AFireEvent ( 535 id, 536 nextTime, 537 net.info.idIndex, 538 ( ( ANetwork ) net ).aData.getId ( ( ANetwork ) net ) ) ); 539 } 540 else if ( net.getClass ( ).getName ( ).equals ( 541 "cnslab.cnsnetwork.Network" ) ) 542 { 543 fireEventSlot.offer ( new FireEvent ( id, nextTime ) ); 544 } 545 else 546 { 547 throw new RuntimeException ( 548 "Other Network Class doesn't exist" ); 549 } 550 551 timeOfNextFire = nextTime; 552 } 553 } 554 555 @Override 556 public double updateFire ( ) 557 //////////////////////////////////////////////////////////////////////// 558 { 559 /* 560 * I. If this is a real fire event: 561 * A. Update spike time dependent plasticity variables. 562 * B. Update the state variables to time at the end of the 563 * refractory period. 564 * C. Reset the membrane voltage, membrane voltage threshold, and 565 * spike-induced currents according to the reset rules of the 566 * Mihalas-Niebur model. 567 * II. Schedule the next fire event. 568 */ 569 // The time to start looking for the next fire event. 570 571 double baselineTime; 572 573 // The voltage and threshold at baselineTime. 574 575 double nowVoltage, nowThreshold; 576 577 if ( fire ) // Only proceed if this is an actual fire event. 578 { 579 580 // Update STDP variables. 581 582 for ( final Map.Entry<Synapse,TimeWeight> entry 583 : histTable.entrySet ( ) ) 584 { 585 final Synapse syn = entry.getKey ( ); 586 587 // Get the relative weight. 588 589 final TimeWeight tw = entry.getValue(); 590 591 // Channel 0 has STDP. 592 593 if ( syn.getType ( ) == 0 ) 594 { 595 // LTP only for close spikes. 596 597 if ( lastFireTime < tw.time ) 598 { 599 // Update the weight. 600 601 tw.weight = tw.weight 602 * ( 1 + para.Alpha_LTP 603 * Math.exp ( -para.K_LTP 604 * ( timeOfNextFire-tw.time ) ) ); 605 } 606 } 607 } 608 609 // Store the old neuron fire time. 610 611 lastFireTime = timeOfNextFire; 612 613 // Back up the current threshold. 614 615 final double thresholdBackup 616 = membraneVoltage ( timeOfNextFire ) 617 - thresholdDiff ( timeOfNextFire ); 618 619 // Update the G term to the time at the end of the refractory period. 620 621 sta_p [ State.NG ] [ 0 ].value *= Math.exp ( 622 -( timeOfNextFire - sta_p [ State.NG ] [ 0 ].time + para.ABSREF ) 623 * para.allDecays [ State.NG ] [ 0 ]); 624 625 sta_p [ State.NG ] [ 0 ].time = timeOfNextFire + para.ABSREF; 626 627 // Update the B term to the time at the end of the refractory period. 628 629 sta_p [ State.NB ] [ 0 ].value *= Math.exp ( 630 -( timeOfNextFire - sta_p [ State.NB ] [ 0 ].time + para.ABSREF ) 631 * para.allDecays [ State.NB ] [ 0 ]); 632 633 sta_p [ State.NB ] [ 0 ].time = timeOfNextFire + para.ABSREF; 634 635 // Reset the membrane voltage and voltage threshold. 636 637 // Calculate what the voltage and threshold will be 638 // at the end of the refractory period. 639 640 final double 641 futureVoltage = membraneVoltage ( timeOfNextFire + para.ABSREF ); 642 643 final double 644 futureThresholdGap = thresholdDiff (timeOfNextFire + para.ABSREF); 645 646 final double futureThreshold = futureVoltage - futureThresholdGap; 647 648 // Calculate the reset voltage and threshold, according 649 // to the reset rules of the model. 650 651 final double resetVoltage = para.VRESET; 652 653 final double resetThreshold = 654 Math.max(para.RRESET, thresholdBackup + para.THRESHOLDADD); 655 656 // Calculate the necessary offsets. 657 658 final double voltageOffset = resetVoltage - futureVoltage; 659 660 final double thresholdOffset = resetThreshold - futureThreshold; 661 662 // Incorporate the offsets into the state variables. 663 664 sta_p [ State.NG ] [ 0 ].value += 665 voltageOffset * para.VOLTAGE_TO_NG; 666 667 sta_p [ State.NB ] [ 0 ].value += 668 voltageOffset * para.VOLTAGE_TO_NB; 669 670 sta_p [ State.NB ] [ 0 ].value += 671 thresholdOffset * para.THRESHOLD_TO_NB; 672 673 // Clamp the voltage and threshold to the reset values. 674 675 clampVoltage = resetVoltage; 676 677 clampThreshold = resetThreshold; 678 679 // Reset the spike-induced currents. 680 681 for (int a = 0; a < para.allDecays [ State.NJ_SPIKE ].length; a++) 682 { 683 // Reset a current only if its reset rule is not the identity. 684 685 if ( para.SPIKE_RATIO [ a ] != 1.0 || para.SPIKE_ADD [ a ] != 0 ) 686 { 687 // Update the current to the end of the refractory period. 688 689 // If the current is inactive, initialize its state variable. 690 691 if ( sta_p [ State.NJ_SPIKE ] [ a ] == null ) 692 { 693 sta_p [ State.NJ_SPIKE ] [ a ] = new State ( 694 timeOfNextFire + para.ABSREF, 695 0, 696 State.NJ_SPIKE, 697 a); 698 699 state.add ( sta_p [ State.NJ_SPIKE ] [ a ]); 700 } 701 else 702 { 703 sta_p [ State.NJ_SPIKE ] [ a ].value *= Math.exp ( 704 -( timeOfNextFire - 705 sta_p [ State.NJ_SPIKE ] [ a ].time + para.ABSREF ) 706 * para.allDecays [ State.NJ_SPIKE ] [ a ] ); 707 708 sta_p [ State.NJ_SPIKE ] [ a ].time = 709 timeOfNextFire + para.ABSREF; 710 } 711 712 // Calculate what this current will be 713 // at the end of the refractory period. 714 715 final double futureCurrent = 716 sta_p [ State.NJ_SPIKE ] [ a ].value 717 / para.SPIKE_CURRENT_TO_SPIKE_NJ [ a ]; 718 719 // Calculate the reset current, according to the reset rule 720 // of the model. 721 722 final double resetCurrent = 723 ( para.SPIKE_RATIO [ a ]* futureCurrent 724 * Math.exp ( para.ABSREF * 725 para.allDecays [ State.NJ_SPIKE ] [ a ] ) 726 + para.SPIKE_ADD [ a ] ) 727 * Math.exp ( -para.ABSREF * 728 para.allDecays [ State.NJ_SPIKE ] [ a ] ); 729 730 // Calculate the necessary offset. 731 732 final double currentOffset = resetCurrent - futureCurrent; 733 734 // Incorporate the offset into the state variables. 735 736 sta_p [ State.NG ] [ 0 ].value += 737 currentOffset * para.SPIKE_CURRENT_TO_NG [ a ]; 738 739 sta_p [ State.NB ] [ 0 ].value += 740 currentOffset * para.SPIKE_CURRENT_TO_NB [ a ]; 741 742 sta_p [ State.NJ_SPIKE ] [ a ].value += 743 currentOffset * para.SPIKE_CURRENT_TO_SPIKE_NJ [ a ]; 744 } 745 } 746 747 timeOfLastUpdate = timeOfNextFire + para.ABSREF; 748 749 baselineTime = timeOfNextFire + para.ABSREF; 750 751 nowVoltage = resetVoltage; 752 753 nowThreshold = resetThreshold; 754 } 755 else // If this fire event was only a preliminary prediction. 756 { 757 baselineTime = timeOfNextFire; 758 759 nowVoltage = thresholdDiff(baselineTime); 760 761 nowThreshold = 0; 762 } 763 764 // Schedule the next fire event, if one should exist. 765 766 // Use a modified Newton-Raphson (NR) root-finding iteration method to 767 // find the first zero of the thresholdDiff function. 768 769 double deltaT, thresholdDiff; 770 771 double deriv = safeDerivative ( baselineTime ); 772 773 // Ever-increasing estimate of the time until the next fire event. 774 775 double nextTime = ( -( nowVoltage - nowThreshold ) ) / deriv; 776 777 deltaT = nextTime; 778 779 // Repeat the NR iteration until either the derivative turns negative 780 // or the predicted fire time grows so far away that an input spike 781 // is likely to arrive in the intervening time, rendering any further 782 // calculations useless (see D'Haene, 2009). 783 784 while ( deriv > 0 && nextTime < cost * tAvg ) 785 { 786 // Safe derivative of the thresholdDiff function. 787 788 deriv = safeDerivative ( baselineTime + nextTime ); 789 790 // Value of the thresholdDiff function. 791 792 thresholdDiff = thresholdDiff ( baselineTime + nextTime ); 793 794 // As per the NR algorithm. 795 796 deltaT = ( -thresholdDiff / deriv ); 797 798 // If deltaT is within the proper precision, schedule an actual 799 // fire event. 800 801 if ( deltaT > 0 ? deltaT < tEPS : -deltaT < tEPS ) 802 { 803 fire = true; 804 805 return nextTime + baselineTime - timeOfNextFire; 806 } 807 808 nextTime += deltaT; 809 } 810 811 // If the derivative is decreasing or if the algorithm has gone far 812 // enough, conclude that this neuron will not fire. 813 814 if ( deriv < 0 || nextTime > 1.0 ) 815 { 816 fire = false; 817 818 return -1.0; 819 } 820 821 // Otherwise, schedule a preliminary estimate fire event. 822 823 else 824 { 825 fire = false; 826 827 return nextTime + baselineTime - timeOfNextFire; 828 } 829 } 830 831 @Override 832 public double updateInput ( 833 final double time, 834 final Synapse input ) 835 //////////////////////////////////////////////////////////////////////// 836 { 837 /* 838 * I. Update spike time dependent plasticity variables. 839 * II. Update the appropriate state variables to the current time. 840 * III. Add the current from the input synapse into the appropriate 841 * state variables. 842 * IV. Schedule the next fire event. 843 */ 844 845 // Update STDP variables. 846 847 if ( histTable.containsKey ( input ) ) 848 { 849 // If this synapse is already in the table, just update its value. 850 851 TimeWeight tw = histTable.get ( input ); 852 853 if ( input.getType ( ) == 0 ) // Channel 0 has STDP. 854 { 855 if ( lastFireTime > tw.time ) // LTD only for close spikes. 856 { 857 // Update weight if neuron fired before. 858 859 tw.weight = tw.weight * ( 1 - para.Alpha_LTD * Math.exp ( 860 -para.K_LTD * ( time - lastFireTime ) ) ); 861 } 862 } 863 864 tw.time = time; // Update the time. 865 } 866 else // If this synapse did not fire before, add it to the table. 867 { 868 TimeWeight tw = new TimeWeight ( time, 1.0 ); 869 870 if ( input.getType ( ) == 0 ) // Channel 0 has STDP. 871 { 872 if ( lastFireTime > 0.0 ) 873 { 874 // Update weight if neuron fired before. 875 876 tw.weight = tw.weight * ( 1 - para.Alpha_LTD * Math.exp ( 877 -para.K_LTD * ( time - lastFireTime ) ) ); 878 } 879 } 880 881 // Put the default weight into the history table. 882 883 histTable.put ( input, tw ); 884 } 885 886 // Update the running average of mean time intervals in between spikes. 887 888 tAvg = tAvg * 0.8 + ( time - lastInputTime ) *0.2; 889 890 // Store the time of this spike. 891 892 lastInputTime = time; 893 894 // Update the G term to the current time. 895 896 sta_p [ State.NG ] [ 0 ].value *= Math.exp ( 897 -( time - sta_p [ State.NG ] [ 0 ].time ) 898 * para.allDecays [ State.NG ] [ 0 ]); 899 900 sta_p [ State.NG ] [ 0 ].time = time; 901 902 // Update the B term to the current time. 903 904 sta_p [ State.NB ] [ 0 ].value *= Math.exp ( 905 -( time - sta_p [ State.NB ] [ 0 ].time ) 906 * para.allDecays [ State.NB ] [ 0 ]); 907 908 sta_p [ State.NB ] [ 0 ].time = time; 909 910 // Update the appropriate current to the current time. 911 912 // If the appropriate current's channel is inactive, initialize its 913 // state variable. 914 915 int channel = input.getType ( ); 916 917 if ( sta_p [ State.NJ_SYNAPSE ] [ channel ] == null ) 918 { 919 sta_p [ State.NJ_SYNAPSE ] [ channel ] = 920 new State ( time, 0, State.NJ_SYNAPSE, channel); 921 922 state.add ( sta_p [ State.NJ_SYNAPSE ] [ channel ] ); 923 } 924 else 925 { 926 sta_p [ State.NJ_SYNAPSE ] [ channel ].value *= 927 Math.exp ( -( time - 928 sta_p [ State.NJ_SYNAPSE ] [ channel ].time ) 929 * para.allDecays [ State.NJ_SYNAPSE ] [ channel ] ); 930 931 sta_p [ State.NJ_SYNAPSE ] [ channel ].time = time; 932 } 933 934 // Add the input spike to the G term state variable. 935 936 sta_p [ State.NG ] [ 0 ].value += 937 input.getWeight ( ) * ( histTable.get ( input ).weight ) 938 * para.SYNAPSE_CURRENT_TO_NG [ channel ]; 939 940 // Add the input spike to the B term state variable. 941 942 sta_p [ State.NB ] [ 0 ].value += 943 input.getWeight ( ) * ( histTable.get ( input ).weight ) 944 * para.SYNAPSE_CURRENT_TO_NB [ channel ]; 945 946 // Add the input spike to the appropriate current's state variable. 947 948 sta_p [ State.NJ_SYNAPSE ] [ 0 ].value 949 += input.getWeight ( ) 950 * histTable.get ( input ).weight 951 * para.SYNAPSE_CURRENT_TO_SYNAPSE_NJ [ channel ]; 952 953 // The time to start looking for the next fire event. 954 955 double baselineTime; 956 957 // The voltage and threshold at baselineTime. 958 959 double nowVoltage = membraneVoltage(time); 960 961 double nowThreshold = nowVoltage - thresholdDiff(time); 962 963 // If this neuron is still within a refractory period, offset the 964 // membrane voltage and threshold such that by the end of the 965 // refractory period, the voltage will be equal to the clamp voltage. 966 967 if ( time < timeOfLastUpdate ) 968 { 969 970 // Calculate what the voltage and threshold will be 971 // at the end of the refractory period. 972 973 double futureVoltage = membraneVoltage ( timeOfLastUpdate ); 974 975 double futureThresholdGap = thresholdDiff ( timeOfLastUpdate ); 976 977 double futureThreshold = futureVoltage - futureThresholdGap; 978 979 // Calculate the necessary offsets. 980 981 double voltageOffset = clampVoltage - futureVoltage; 982 983 double thresholdOffset = clampThreshold - futureThreshold; 984 985 // Incorporate the offsets into the state variables. 986 987 double ngDecay = Math.exp ( ( timeOfLastUpdate - time ) 988 * para.allDecays [ State.NG ] [ 0 ] ); 989 990 double nbDecay = Math.exp ( ( timeOfLastUpdate - time ) 991 * para.allDecays [ State.NB ] [ 0 ] ); 992 993 sta_p [ State.NG ] [ 0 ].value += para.VOLTAGE_TO_NG * voltageOffset 994 * ngDecay; 995 996 sta_p [ State.NB ] [ 0 ].value += para.VOLTAGE_TO_NB * voltageOffset 997 * nbDecay; 998 999 sta_p [ State.NB ] [ 0 ].value += para.THRESHOLD_TO_NB 1000 * thresholdOffset * nbDecay; 1001 1002 nowVoltage = clampVoltage; 1003 1004 nowThreshold = clampThreshold; 1005 1006 baselineTime = timeOfLastUpdate; 1007 } 1008 else 1009 { 1010 clampVoltage = Double.MAX_VALUE; 1011 1012 clampThreshold = Double.MAX_VALUE; 1013 1014 timeOfLastUpdate = time; 1015 1016 baselineTime = time; 1017 } 1018 1019 // Schedule the next fire event, if one should exist. 1020 1021 // Use a modified Newton-Raphson (NR) root-finding iteration method to 1022 // find the first zero of the thresholdDiff function. 1023 1024 double deltaT, thresholdDiff; 1025 1026 double deriv = safeDerivative ( baselineTime ); 1027 1028 // Ever-increasing estimate of the time until the next fire event. 1029 1030 double nextTime = ( -( nowVoltage - nowThreshold ) / deriv ); 1031 1032 deltaT = nextTime; 1033 1034 // Repeat the NR iteration until either the derivative turns negative 1035 // or the predicted fire time grows so far away that an input spike 1036 // is likely to arrive in the intervening time, rendering any further 1037 // calculations useless (see D'Haene, 2009). 1038 1039 while ( deriv > 0 && nextTime < cost * tAvg ) 1040 { 1041 1042 // Safe derivative of the thresholdDiff function. 1043 1044 deriv = safeDerivative ( baselineTime + nextTime ); 1045 1046 // Value of the thresholdDiff function. 1047 1048 thresholdDiff = thresholdDiff ( baselineTime + nextTime ); 1049 1050 // As per the NR algorithm. 1051 1052 deltaT = ( -thresholdDiff / deriv ); 1053 1054 // If deltaT is within the proper precision, schedule an actual 1055 // fire event. 1056 1057 if ( deltaT > 0 ? deltaT < tEPS : -deltaT < tEPS ) 1058 { 1059 fire = true; 1060 1061 return ( time > timeOfLastUpdate 1062 ? nextTime : nextTime + timeOfLastUpdate - time ); 1063 } 1064 1065 nextTime += deltaT; 1066 } 1067 1068 // If the derivative is decreasing or if the algorithm has gone far 1069 // enough, conclude that this neuron will not fire. 1070 1071 if( deriv < 0 || nextTime > 1.0 ) 1072 { 1073 fire = false; 1074 1075 return -1.0; 1076 } 1077 1078 // Otherwise, schedule a preliminary estimate fire event. 1079 1080 else 1081 { 1082 fire = false; 1083 1084 return ( time > timeOfLastUpdate 1085 ? nextTime : nextTime + timeOfLastUpdate - time ); 1086 } 1087 } 1088 1089 1090 //////////////////////////////////////////////////////////////////////// 1091 // accessor methods 1092 //////////////////////////////////////////////////////////////////////// 1093 1094 public double getSensoryWeight ( ) 1095 //////////////////////////////////////////////////////////////////////// 1096 { 1097 throw new RuntimeException ( "This neuron type doesn't use the " 1098 + "Sensory Weight Functions!" ); 1099 } 1100 1101 public double getTimeOfLastUpdate ( ) 1102 //////////////////////////////////////////////////////////////////////// 1103 { 1104 return this.timeOfLastUpdate; 1105 } 1106 1107 1108 //////////////////////////////////////////////////////////////////////// 1109 // mutator methods 1110 //////////////////////////////////////////////////////////////////////// 1111 1112 /*********************************************************************** 1113 * set the membrane voltage 1114 ***********************************************************************/ 1115 public void setMemV ( double memV ) 1116 //////////////////////////////////////////////////////////////////////// 1117 { 1118 return; 1119 } 1120 1121 public void setTimeOfLastUpdate ( final double timeOfLastUpdate ) 1122 //////////////////////////////////////////////////////////////////////// 1123 { 1124 this.timeOfLastUpdate = timeOfLastUpdate; 1125 } 1126 1127 //////////////////////////////////////////////////////////////////////// 1128 // overridden Object methods 1129 //////////////////////////////////////////////////////////////////////// 1130 1131 @Override 1132 public String toString ( ) 1133 //////////////////////////////////////////////////////////////////////// 1134 { 1135 String tmp=""; 1136 1137 tmp=tmp+"Current:"+"\n"; 1138 1139 return tmp; 1140 } 1141 1142 //////////////////////////////////////////////////////////////////////// 1143 // miscellaneous methods 1144 //////////////////////////////////////////////////////////////////////// 1145 1146 1147 /*********************************************************************** 1148 * Compute the membrane voltage. 1149 * 1150 * This method transforms the thresholdDiff (Mihalas and Niebur, 2009, 1151 * Equation 3.5) into the voltage ( V(t) in Equation 3.2) 1152 * 1153 * @param t 1154 * absolute time t 1155 * @return membrane voltage at time t.s 1156 **********************************************************************/ 1157 public double membraneVoltage ( final double t ) 1158 //////////////////////////////////////////////////////////////////////// 1159 { 1160 1161 double V = 0; 1162 1163 double constant = 0; 1164 1165 Iterator<State> iter = state.iterator ( ); 1166 1167 State first = iter.next ( ); 1168 1169 State second; 1170 1171 double firstV, secondV; 1172 1173 if ( first.type == State.NB ) // Ignore the NB term. 1174 { 1175 if ( iter.hasNext ( ) ) 1176 { 1177 first = iter.next ( ); 1178 } 1179 else 1180 { 1181 return V + constant + para.MEMBRANE_VOLTAGE_BASE + para.VREST; 1182 } 1183 } 1184 1185 if ( first.time == t ) // Ignore no time change term. 1186 // TODO: what does that mean? 1187 { 1188 if ( first.type == State.NB ) 1189 { 1190 constant += first.value / para.VOLTAGE_TO_NG; 1191 } 1192 else if (first.type == State.NJ_SPIKE) 1193 { 1194 constant += first.value 1195 / para.VOLTAGE_TO_SPIKE_NJ [ first.index ]; 1196 } 1197 else if (first.type == State.NJ_SYNAPSE) 1198 { 1199 constant += first.value 1200 / para.VOLTAGE_TO_SYNAPSE_NJ [ first.index ]; 1201 } 1202 1203 if ( iter.hasNext ( ) ) 1204 { 1205 first = iter.next ( ); 1206 } 1207 else 1208 { 1209 return V + constant + para.MEMBRANE_VOLTAGE_BASE + para.VREST; 1210 } 1211 } 1212 1213 if ( first.type == State.NB ) // Ignore the NB term. 1214 { 1215 if ( iter.hasNext ( ) ) 1216 { 1217 first = iter.next ( ); 1218 } 1219 else 1220 { 1221 return V + constant + para.MEMBRANE_VOLTAGE_BASE + para.VREST; 1222 } 1223 } 1224 1225 if ( first.type == State.NG ) 1226 { 1227 firstV = first.value / para.VOLTAGE_TO_NG; 1228 } 1229 else if (first.type == State.NJ_SPIKE) 1230 { 1231 firstV = first.value / para.VOLTAGE_TO_SPIKE_NJ [ first.index ]; 1232 } 1233 else 1234 { 1235 firstV = first.value / para.VOLTAGE_TO_SYNAPSE_NJ [ first.index ]; 1236 } 1237 1238 V = firstV; 1239 1240 while ( iter.hasNext ( ) ) 1241 { 1242 second = iter.next ( ); 1243 1244 secondV = 0.0; 1245 1246 if ( second.time == t && second.type != State.NB ) 1247 { 1248 if ( second.type == State.NG ) 1249 { 1250 constant += second.value / para.VOLTAGE_TO_NG; 1251 } 1252 else if (second.type == State.NJ_SPIKE) 1253 { 1254 constant += second.value 1255 / para.VOLTAGE_TO_SPIKE_NJ [ second.index ]; 1256 } 1257 else 1258 { 1259 constant += second.value 1260 / para.VOLTAGE_TO_SYNAPSE_NJ [ second.index ]; 1261 } 1262 1263 if ( iter.hasNext ( ) ) 1264 { 1265 second = iter.next ( ); 1266 } 1267 else 1268 { 1269 break; 1270 } 1271 } 1272 1273 if ( second.type == State.NB ) 1274 { 1275 if ( iter.hasNext ( ) ) 1276 { 1277 second = iter.next ( ); 1278 } 1279 else 1280 { 1281 break; 1282 } 1283 } 1284 1285 if ( second.type == State.NG ) 1286 { 1287 secondV = second.value / para.VOLTAGE_TO_NG; 1288 } 1289 else if (second.type == State.NJ_SPIKE) 1290 { 1291 secondV = second.value 1292 / para.VOLTAGE_TO_SPIKE_NJ [ second.index ]; 1293 } 1294 else 1295 { 1296 secondV = second.value 1297 / para.VOLTAGE_TO_SYNAPSE_NJ [ second.index ]; 1298 } 1299 1300 V = Math.exp ( -( ( para.allDecays [ first.type ] [ first.index ] 1301 - para.allDecays [ second.type ] [ second.index ] ) * t 1302 - ( para.allDecays [ first.type ] [ first.index ] * first.time 1303 - para.allDecays [ second.type ] [ second.index ] * second.time ) 1304 ) ) * V + secondV; 1305 1306 first = second; 1307 } 1308 1309 V = Math.exp ( -( para.allDecays [ first.type ] [ first.index ] * t 1310 - para.allDecays [ first.type ] [ first.index ] * first.time )) * V; 1311 1312 return V + para.MEMBRANE_VOLTAGE_BASE + para.VREST + constant; 1313 } 1314 1315 /*********************************************************************** 1316 * Compute the difference between the membrane voltage and the membrane 1317 * voltage threshold. The <code>MNNeuron</code> class is optimized to 1318 * perform this function quickly, at the expense of 1319 * <code>membraneVoltage()</code>. 1320 * 1321 * This method is based on Equation 3.5 in Mihalas and Niebur, 2009. 1322 * 1323 * @param t 1324 * absolute time t 1325 * @return voltage - threshold at time t 1326 **********************************************************************/ 1327 public double thresholdDiff ( double t ) 1328 //////////////////////////////////////////////////////////////////////// 1329 { 1330 double V = 0, constant = 0.0; 1331 1332 Iterator<State> iter = state.iterator ( ); 1333 1334 State first = iter.next ( ); 1335 State second; 1336 1337 double firstV, secondV; 1338 1339 // If the state is a current channel and its value is too small, delete 1340 // it. 1341 1342 if ( Math.abs ( first.value ) < rEPS 1343 && ( int ) first.type != State.NG 1344 && ( int ) first.type != State.NB ) 1345 { 1346 iter.remove ( ); 1347 1348 sta_p [ first.type ] [ first.index ] = null; 1349 } 1350 1351 // ignore no time change term 1352 1353 if ( first.time == t ) 1354 { 1355 constant += first.value; 1356 1357 if ( iter.hasNext ( ) ) 1358 { 1359 first = iter.next ( ); 1360 } 1361 else 1362 { 1363 return V + constant + para.THRESHOLD_DIFF_BASE; 1364 } 1365 } 1366 1367 firstV = first.value; 1368 1369 V = firstV; 1370 1371 while ( iter.hasNext ( ) ) 1372 { 1373 second = iter.next ( ); 1374 1375 // If the state is a current channel and its value is too small, 1376 // delete it. 1377 1378 if ( Math.abs ( second.value ) < rEPS 1379 && second.type != State.NG 1380 && second.type != State.NB ) 1381 { 1382 iter.remove ( ); 1383 1384 sta_p [ second.type ] [ second.index ] = null; 1385 } 1386 1387 secondV = 0.0; 1388 1389 // Ignore no time change term. 1390 1391 if ( second.time == t ) 1392 { 1393 constant += second.value; 1394 1395 if ( iter.hasNext ( ) ) 1396 { 1397 second = iter.next ( ); 1398 } 1399 else 1400 { 1401 break; 1402 } 1403 } 1404 1405 secondV = second.value; 1406 1407 V = Math.exp ( -( ( para.allDecays [ first.type ] [ first.index ] 1408 - para.allDecays [ second.type ] [ second.index ] ) * t 1409 - ( para.allDecays [ first.type ] [ first.index ] * first.time 1410 - para.allDecays [ second.type ] [ second.index ] * second.time ) 1411 ) ) * V + secondV; 1412 1413 first = second; 1414 } 1415 1416 V = Math.exp ( -( para.allDecays [ first.type ] [ first.index ] * t 1417 - para.allDecays [ first.type ] [ first.index ] * first.time )) * V; 1418 1419 return V + para.THRESHOLD_DIFF_BASE + constant; 1420 } 1421 1422 /*********************************************************************** 1423 * Computes the "safe derivative" (see D'Haene, 2009) of the 1424 * thresholdDiff function. 1425 * 1426 * @param t 1427 * absolute time t 1428 ***********************************************************************/ 1429 public double safeDerivative ( double t ) 1430 //////////////////////////////////////////////////////////////////////// 1431 { 1432 // The smallest inverse decay of all negative state variables. 1433 1434 double tauSafe = Double.MAX_VALUE; 1435 1436 // Find the tauSafe. 1437 1438 Iterator<State> iter = state.descendingIterator(); 1439 1440 while(iter.hasNext()) 1441 { 1442 State tmpState = iter.next(); 1443 1444 if ( tmpState.value < 0 ) 1445 { 1446 tauSafe = para.allDecays [ tmpState.type ] [ tmpState.index]; 1447 1448 break; 1449 } 1450 } 1451 1452 double constant = 0; 1453 1454 double V = 0; 1455 1456 iter = state.iterator ( ); 1457 1458 State first = iter.next ( ); 1459 1460 State second; 1461 1462 double firstV, secondV; 1463 1464 if ( first.time == t ) 1465 { 1466 // Ignore no time change term. 1467 1468 if ( first.value > 0 ) 1469 { 1470 final double maxDecay = Math.min ( 1471 para.allDecays [ first.type ] [ first.index ], 1472 tauSafe ); 1473 1474 constant += -first.value * maxDecay; 1475 } 1476 else 1477 { 1478 constant += 1479 -first.value * para.allDecays [ first.type ] [ first.index ]; 1480 } 1481 1482 if ( iter.hasNext ( ) ) 1483 { 1484 first = iter.next ( ); 1485 } 1486 else 1487 { 1488 return V + constant; 1489 } 1490 } 1491 1492 if ( first.value > 0 ) 1493 { 1494 final double maxDecay = Math.min ( 1495 para.allDecays [ first.type ] [ first.index ], 1496 tauSafe ); 1497 1498 firstV = -first.value * maxDecay; 1499 } 1500 else 1501 { 1502 firstV = 1503 -first.value * para.allDecays [ first.type ] [ first.index ]; 1504 } 1505 1506 V = firstV; 1507 1508 while ( iter.hasNext ( ) ) 1509 { 1510 second = iter.next ( ); 1511 1512 secondV = 0.0; 1513 1514 if( second.time == t) 1515 { 1516 if ( second.value > 0 ) 1517 { 1518 final double maxDecay = Math.min ( 1519 para.allDecays [ second.type ] [ second.index ], 1520 tauSafe ); 1521 1522 constant += -second.value * maxDecay; 1523 } 1524 else 1525 { 1526 constant 1527 += -second.value 1528 * para.allDecays [ second.type ] [ second.index ]; 1529 } 1530 1531 if ( iter.hasNext ( ) ) 1532 { 1533 second = iter.next ( ); 1534 } 1535 else 1536 { 1537 break; 1538 } 1539 } 1540 1541 if ( second.value > 0 ) 1542 { 1543 final double maxDecay = Math.min ( 1544 para.allDecays [ second.type ] [ second.index ], 1545 tauSafe ); 1546 1547 secondV = -second.value * maxDecay; 1548 } 1549 else 1550 { 1551 secondV = 1552 -second.value * para.allDecays [ second.type] [ second.index]; 1553 } 1554 1555 V = Math.exp ( -( ( para.allDecays [ first.type ] [ first.index] 1556 - para.allDecays [ second.type ] [ second.index] ) * t 1557 - ( para.allDecays [ first.type ] [ first.index ] * first.time 1558 - para.allDecays [ second.type ] [ second.index] * second.time ) 1559 ) ) * V + secondV; 1560 1561 first = second; 1562 } 1563 1564 V = Math.exp ( -( para.allDecays [first.type ] [ first.index ] * t 1565 - para.allDecays [ first.type ] [ first.index ] * first.time )) * V; 1566 1567 return V + constant; 1568 } 1569 1570 //////////////////////////////////////////////////////////////////////// 1571 //////////////////////////////////////////////////////////////////////// 1572 }