CARLsim  3.1.3
CARLsim: a GPU-accelerated SNN simulator
simple_weight_tuner.cpp
Go to the documentation of this file.
1 #include "simple_weight_tuner.h"
2 
3 #include <carlsim.h> // CARLsim, SpikeMonitor
4 #include <math.h> // fabs
5 #include <stdio.h> // printf
6 #include <limits> // double::max
7 #include <assert.h> // assert
8 
9 // ****************************************************************************************************************** //
10 // SIMPLEWEIGHTTUNER UTILITY PRIVATE IMPLEMENTATION
11 // ****************************************************************************************************************** //
12 
21 public:
22  // +++++ PUBLIC METHODS +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ //
23 
24  Impl(CARLsim *sim, double errorMargin, int maxIter, double stepSizeFraction) {
25  assert(sim!=NULL);
26  assert(errorMargin>0);
27  assert(maxIter>0);
28  assert(stepSizeFraction>0.0f && stepSizeFraction<=1.0f);
29 
30  sim_ = sim;
31  assert(sim_->getCARLsimState()!=RUN_STATE);
32 
33  errorMargin_ = errorMargin;
34  stepSizeFraction_ = stepSizeFraction;
35  maxIter_ = maxIter;
36 
37  connId_ = -1;
38  wtRange_ = NULL;
39  wtInit_ = -1.0;
40 
41  grpId_ = -1;
42  targetRate_ = -1.0;
43 
44  wtStepSize_ = -1.0;
45  cntIter_ = 0;
46 
47  wtShouldIncrease_ = true;
48  adjustRange_ = true;
49 
50  needToInitConnection_ = true;
51  needToInitTargetFiring_ = true;
52 
53  needToInitAlgo_ = true;
54  }
55 
56  ~Impl() {
57  if (wtRange_!=NULL)
58  delete wtRange_;
59  wtRange_=NULL;
60  }
61 
62 // user function to reset algo
63 void reset() {
64  needToInitAlgo_ = true;
65  initAlgo();
66 }
67 
68 bool done(bool printMessage) {
69  // algo not initalized: we're not done
70  if (needToInitConnection_ || needToInitTargetFiring_ || needToInitAlgo_)
71  return false;
72 
73  // success: margin reached
74  if (fabs(currentError_) < errorMargin_) {
75  if (printMessage) {
76  printf("SimpleWeightTuner successful: Error margin reached in %d iterations.\n",cntIter_);
77  }
78  return true;
79  }
80 
81  // failure: max iter reached
82  if (cntIter_ >= maxIter_) {
83  if (printMessage) {
84  printf("SimpleWeightTuner failed: Max number of iterations (%d) reached.\n",maxIter_);
85  }
86  return true;
87  }
88 
89  // else we're not done
90  return false;
91 }
92 
93 void setConnectionToTune(short int connId, double initWt, bool adjustRange) {
94  assert(connId>=0 && connId<sim_->getNumConnections());
95 
96  connId_ = connId;
97  wtInit_ = initWt;
98  adjustRange_ = adjustRange;
99 
100  needToInitConnection_ = false;
101  needToInitAlgo_ = true;
102 }
103 
104 void setTargetFiringRate(int grpId, double targetRate) {
105  grpId_ = grpId;
106  targetRate_ = targetRate;
107  currentError_ = targetRate;
108 
109  // check whether group has SpikeMonitor
110  SM_ = sim_->getSpikeMonitor(grpId);
111  if (SM_==NULL) {
112  // setSpikeMonitor has not been called yet
113  SM_ = sim_->setSpikeMonitor(grpId,"NULL");
114  }
115 
116  needToInitTargetFiring_ = false;
117  needToInitAlgo_ = true;
118 }
119 
120 void iterate(int runDurationMs, bool printStatus) {
121  assert(runDurationMs>0);
122 
123  // if we're done, don't iterate
124  if (done(printStatus)) {
125  return;
126  }
127 
128  // make sure we have initialized algo
129  assert(!needToInitConnection_);
130  assert(!needToInitTargetFiring_);
131  if (needToInitAlgo_)
132  initAlgo();
133 
134  // in case the user has already been messing with the SpikeMonitor, we need to make sure that
135  // PersistentMode is off
136  SM_->setPersistentData(false);
137 
138  // now iterate
139  SM_->startRecording();
140  sim_->runNetwork(runDurationMs/1000, runDurationMs%1000, false);
141  SM_->stopRecording();
142 
143  double thisRate = SM_->getPopMeanFiringRate();
144  if (printStatus) {
145  printf("#%d: rate=%.4fHz, target=%.4fHz, error=%.7f, errorMargin=%.7f\n", cntIter_, thisRate, targetRate_,
146  thisRate-targetRate_, errorMargin_);
147  }
148 
149  currentError_ = thisRate - targetRate_;
150  cntIter_++;
151 
152  // check if we're done now
153  if (done(printStatus)) {
154  return;
155  }
156 
157  // else update parameters
158  if ((wtStepSize_>0 && thisRate>targetRate_) || (wtStepSize_<0 && thisRate<targetRate_)) {
159  // we stepped too far to the right or too far to the left
160  // turn around and cut step size in half
161  // note that this should work for inhibitory connections, too: they have negative weights, so adding
162  // to the weight will actually decrease it (make it less negative)
163  wtStepSize_ = -wtStepSize_/2.0;
164  }
165 
166  // find new weight
167  sim_->biasWeights(connId_, wtStepSize_, adjustRange_);
168 }
169 
170 private:
171  // +++++ PRIVATE METHODS ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ //
172 
173 // need to call this whenever connection or target firing changes
174 // or when user calls reset
175 void initAlgo() {
176  if (!needToInitAlgo_)
177  return;
178 
179  // make sure we have all the data structures we need
180  assert(!needToInitConnection_);
181  assert(!needToInitTargetFiring_);
182 
183  // update weight ranges
184  RangeWeight wt = sim_->getWeightRange(connId_);
185  wtRange_ = new RangeWeight(wt.min, wt.init, wt.max);
186 
187  // reset algo
188  wtShouldIncrease_ = true;
189  wtStepSize_ = stepSizeFraction_ * (wtRange_->max - wtRange_->min);
190 #if defined(WIN32) || defined(WIN64)
191  currentError_ = DBL_MAX;
192 #else
193  currentError_ = std::numeric_limits<double>::max();
194 #endif
195 
196  // make sure we're in the right CARLsim state
197  if (sim_->getCARLsimState()!=RUN_STATE)
198  sim_->runNetwork(0,0,false);
199 
200  // initialize weights
201  if (wtInit_>=0) {
202  // start at some specified initWt
203  if (wt.init != wtInit_) {
204  // specified starting point is not what is specified in connect
205 
206  sim_->biasWeights(connId_, wtInit_ - wt.init, adjustRange_);
207  }
208  }
209 
210  needToInitAlgo_ = false;
211 }
212 
213 
214  // +++++ PRIVATE STATIC PROPERTIES ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ //
215 
216  // +++++ PRIVATE PROPERTIES +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ //
217 
218  // flags that manage state
219  bool needToInitConnection_;
220  bool needToInitTargetFiring_;
221  bool needToInitAlgo_;
222 
223  // CARLsim data structures
224  CARLsim *sim_;
225  SpikeMonitor *SM_;
226  int grpId_;
227  short int connId_;
228  RangeWeight* wtRange_;
229 
230  // termination condition params
231  int maxIter_;
232  double errorMargin_;
233  double targetRate_;
234 
235  // params that are updated every iteration step
236  int cntIter_;
237  double wtStepSize_;
238  bool wtShouldIncrease_;
239  double currentError_;
240 
241  // options
242  bool adjustRange_;
243  double wtInit_;
244  double stepSizeFraction_;
245 };
246 
247 
248 // ****************************************************************************************************************** //
249 // SIMPLEWEIGHTTUNER API IMPLEMENTATION
250 // ****************************************************************************************************************** //
251 
252 // create and destroy a pImpl instance
253 SimpleWeightTuner::SimpleWeightTuner(CARLsim* sim, double errorMargin, int maxIter, double stepSizeFraction) :
254  _impl( new Impl(sim, errorMargin, maxIter, stepSizeFraction) ) {}
256 
257 void SimpleWeightTuner::setConnectionToTune(short int connId, double initWt, bool adjustRange) {
258  _impl->setConnectionToTune(connId, initWt, adjustRange);
259 }
260 void SimpleWeightTuner::setTargetFiringRate(int grpId, double targetRate) {
261  _impl->setTargetFiringRate(grpId, targetRate);
262 }
263 void SimpleWeightTuner::iterate(int runDurationMs, bool printStatus) { _impl->iterate(runDurationMs, printStatus); }
264 bool SimpleWeightTuner::done(bool printMessage) { return _impl->done(printMessage); }
265 void SimpleWeightTuner::reset() { _impl->reset(); }
void startRecording()
Starts a new recording period.
carlsimState_t getCARLsimState()
Returns the current CARLsim state.
Definition: carlsim.h:1259
void setTargetFiringRate(int grpId, double targetRate)
CARLsim User Interface This class provides a user interface to the public sections of CARLsimCore sou...
Definition: carlsim.h:141
float getPopMeanFiringRate()
Returns the mean firing rate of the entire neuronal population.
void setConnectionToTune(short int connId, double initWt=-1.0, bool adjustRange=true)
Sets up the connection to tune.
~SimpleWeightTuner()
Destructor.
void setConnectionToTune(short int connId, double initWt, bool adjustRange)
Class SpikeMonitor.
SimpleWeightTuner(CARLsim *sim, double errorMargin=1e-3, int maxIter=100, double stepSizeFraction=0.5)
Creates a new instance of class SimpleWeightTuner.
void stopRecording()
Ends a recording period.
void reset()
Resets the algorithm to initial conditions.
bool done(bool printMessage=false)
Determines whether a termination criterion has been met.
int runNetwork(int nSec, int nMsec=0, bool printRunSummary=true, bool copyState=false)
Runs the simulation for time = (nSec*seconds + nMsec*milliseconds).
void iterate(int runDurationMs, bool printStatus)
void setTargetFiringRate(int grpId, double targetRate)
Sets up the target firing rate of a specific group.
void iterate(int runDurationMs=1000, bool printStatus=true)
Performs an iteration step of the tuning algorithm.
void biasWeights(short int connId, float bias, bool updateWeightRange=false)
Adds a constant bias to the weight of every synapse in the connection.
bool done(bool printMessage)
run state, where the model is stepped
a range struct for synaptic weight magnitudes
Private implementation of the SimpleWeightTuner Utility.
SpikeMonitor * setSpikeMonitor(int grpId, const std::string &fileName)
Sets a Spike Monitor for a group, prints spikes to binary file.
SpikeMonitor * getSpikeMonitor(int grpId)
Returns a pointer to the previously allocated SpikeMonitor object, or NULL otherwise.
void setPersistentData(bool persistentData)
Sets PersistentMode either on (true) or off (false)
Impl(CARLsim *sim, double errorMargin, int maxIter, double stepSizeFraction)
RangeWeight getWeightRange(short int connId)
returns the RangeWeight struct for a specific connection ID