本文整理汇总了Python中emcee.utils.MPIPool类的典型用法代码示例。如果您正苦于以下问题：Python MPIPool类的具体用法？Python MPIPool怎么用？Python MPIPool使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
在下文中一共展示了MPIPool类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_pool
def get_pool(mpi=False, threads=None):
    """Get a pool object to pass to emcee for parallel processing.

    If ``mpi`` is False and ``threads`` is None (or not greater than 1),
    the returned pool is None and sampling runs serially.

    Parameters
    ----------
    mpi : bool
        Use MPI or not. If True, the ``threads`` kwarg is ignored. On
        non-master MPI ranks this function never returns: the rank blocks
        in ``pool.wait()`` and then exits the process.
    threads : int (optional)
        If ``mpi`` is False and ``threads`` is greater than 1, use a Python
        multiprocessing pool with the specified number of processes.

    Returns
    -------
    MPIPool, multiprocessing.Pool or None
    """
    if mpi:
        import sys
        from emcee.utils import MPIPool

        # Initialize the MPI pool
        pool = MPIPool()
        # Only the master rank continues; workers serve tasks, then exit.
        if not pool.is_master():
            pool.wait()
            sys.exit(0)
        print("Running with MPI...")
    # ``threads`` defaults to None, and ``None > 1`` raises TypeError on
    # Python 3, so test for None explicitly before comparing.
    elif threads is not None and threads > 1:
        import multiprocessing

        print("Running with multiprocessing on " + str(threads) + " cores...")
        pool = multiprocessing.Pool(threads)
    else:
        print("Running serial...")
        pool = None
    return pool
示例2: pt_mpi_sample
def pt_mpi_sample(gf, ntemps, nwalkers, burn_steps, sample_steps, thin=1,
                  pool=None, betas=None, pos=None, random_state=None,
                  pos_filename=None, convergence_interval=50):
    """Run ``pt_sample`` on a load-balanced MPI pool.

    Every MPI rank calls this function. Worker ranks block in ``wait()``
    and then exit the process. Only the master rank reaches ``pt_sample``
    and returns its result.

    NOTE(review): the ``pool`` argument is accepted but never used; a fresh
    ``MPIPool(loadbalance=True)`` is always created. Confirm whether callers
    expect their pool to be honoured.
    """
    mpi_pool = MPIPool(loadbalance=True)
    if mpi_pool.is_master():
        return pt_sample(gf, ntemps, nwalkers, burn_steps, sample_steps,
                         thin=thin, pool=mpi_pool, betas=betas, pos=pos,
                         random_state=random_state,
                         pos_filename=pos_filename,
                         convergence_interval=convergence_interval)
    # Worker rank: serve tasks until the master closes the pool, then exit.
    mpi_pool.wait()
    sys.exit(0)
示例3: fit_bim_bh3_curves
def fit_bim_bh3_curves(p0=None):
    """Fit the Bim/BH3 curves with an emcee ensemble sampler over MPI.

    Uses module-level names defined elsewhere in the file: ``nwalkers``,
    ``ndim``, ``data``, ``posterior``, ``burn_steps`` and ``sample_steps``.

    Parameters
    ----------
    p0 : ndarray of shape (nwalkers, ndim), optional
        Initial walker positions. If None, each walker gets three uniform
        draws per data set, followed by six hyperparameter draws
        (fmax mean/sd, k mean/sd, f0 mean/sd).

    Returns
    -------
    emcee.EnsembleSampler
        The sampler after burn-in and main sampling (master rank only;
        worker ranks exit the process). It is also pickled, with its pool
        detached, to ``bimbh3_141125_2.pck``.
    """
    # Choose initial position
    if p0 is None:
        p0 = np.zeros((nwalkers, ndim))
        for walk_ix in range(nwalkers):
            for d_ix in range(len(data)):
                p0[walk_ix, d_ix*3] = np.random.uniform(1, 6)
                p0[walk_ix, d_ix*3 + 1] = np.random.uniform(6e-5, 1e-3)
                p0[walk_ix, d_ix*3 + 2] = np.random.uniform(2, 3)
            hp_ix = len(data)*3  # first hyperparameter column
            p0[walk_ix, hp_ix] = np.random.uniform(1, 6)  # fmax mean
            p0[walk_ix, hp_ix + 1] = np.random.uniform(0, 1)  # fmax sd
            p0[walk_ix, hp_ix + 2] = np.random.uniform(6e-5, 1e-3)  # k mean
            p0[walk_ix, hp_ix + 3] = np.random.uniform(0, 1e-1)  # k sd
            p0[walk_ix, hp_ix + 4] = np.random.uniform(2, 3)  # f0 mean
            p0[walk_ix, hp_ix + 5] = np.random.uniform(0, 1)  # f0 sd

    # Initialize the MPI pool; worker ranks block in wait() and then exit.
    pool = MPIPool()
    if not pool.is_master():
        pool.wait()
        sys.exit(0)

    try:
        # Get the sampler
        sampler = emcee.EnsembleSampler(nwalkers, ndim, posterior, pool=pool)
        # Burn-in (positions only; chain not stored)
        print("Burn-in sampling...")
        pos, _, _ = sampler.run_mcmc(p0, burn_steps, storechain=False)
        sampler.reset()
        # Main sampling
        print("Main sampling...")
        sampler.run_mcmc(pos, sample_steps)
    finally:
        # Close the pool even if sampling fails; otherwise the worker ranks
        # stay blocked in wait() and the MPI job hangs.
        pool.close()

    # Detach the (now closed) pool before pickling the sampler.
    sampler.pool = None
    # Pickle data is bytes: the file must be opened in binary mode.
    with open('bimbh3_141125_2.pck', 'wb') as f:
        pickle.dump(sampler, f)
    return sampler
示例4: tdelay_dt_mcmc
def tdelay_dt_mcmc(run, theta, Niter=20, Nwalkers=10, Ndim=2, sigma_smhm=0.2,
                   nsnap0=15, downsampled='14', flag=None, continue_chain=False):
    '''Run an emcee MCMC over (tdelay, dt) on an MPI pool.

    After every iteration, the current walker positions are appended to
    ``UT.fig_dir() + run + '.tdelay_dt_mcmc.chain.dat'``, one tab-separated
    row per walker. This lets an interrupted run be continued.

    Parameters
    ----------
    run : str
        Run name; used to build the chain file name.
    theta, sigma_smhm, nsnap0, downsampled :
        Forwarded as keyword arguments to the log-probability ``sigM``.
    Niter : int
        Total number of iterations the chain file should contain.
    Nwalkers : int
        Number of walkers.
    Ndim : int
        Number of sampled parameters; only 2 (tdelay, dt) is supported.
    flag :
        Unused.
    continue_chain : bool
        If True and the chain file exists, run only the iterations still
        missing from it. Otherwise, start a fresh, empty chain file.

    Returns
    -------
    None (master rank; worker ranks exit the process).

    Raises
    ------
    ValueError
        If ``Ndim`` is not 2, or if the existing chain is already complete.
    '''
    # Check before touching the chain file: any other Ndim used to fail
    # later with a NameError, after the file had already been truncated.
    if Ndim != 2:
        raise ValueError('tdelay_dt_mcmc only supports Ndim == 2 (tdelay, dt)')
    tdelay_range = [0., 3.]
    dt_range = [0.1, 4.]

    chain_file = ''.join([UT.fig_dir(), run, '.tdelay_dt_mcmc.chain.dat'])
    if os.path.isfile(chain_file) and continue_chain:
        print('Continuing previous MCMC chain!')
        sample = np.loadtxt(chain_file)
        # Each completed iteration wrote Nwalkers rows, so this is the
        # number of iterations still to run.
        Niter = Niter - (float(len(sample)) / float(Nwalkers))
        if Niter <= 0:
            raise ValueError('MCMC chain in %s is already complete' % chain_file)
        print('%s  iterations left to finish' % Niter)
        # NOTE(review): walkers are re-initialized at random below rather
        # than resumed from the last Nwalkers rows of ``sample``; confirm.
    else:
        # Start a new, empty chain file.
        open(chain_file, 'w').close()

    # Initializing Walkers: uniform draws inside the (tdelay, dt) box.
    pos0 = [np.array([np.random.uniform(tdelay_range[0], tdelay_range[1]),
                      np.random.uniform(dt_range[0], dt_range[1])])
            for _ in range(Nwalkers)]

    # Worker ranks block in wait() and then exit.
    pool = MPIPool()
    if not pool.is_master():
        pool.wait()
        sys.exit(0)

    # Forward the caller's settings. They were previously hard-coded to the
    # defaults, so the corresponding arguments had no effect.
    kwargs = {
        'theta': theta,
        'sigma_smhm': sigma_smhm,
        'nsnap0': nsnap0,
        'downsampled': downsampled,
    }
    try:
        sampler = emcee.EnsembleSampler(Nwalkers, Ndim, sigM, pool=pool, kwargs=kwargs)
        for result in sampler.sample(pos0, iterations=int(Niter), storechain=False):
            position = result[0]
            # Append after every iteration so progress survives interruption.
            with open(chain_file, 'a') as f:
                for row in position:
                    f.write('\t'.join(row.astype('str')) + '\n')
    finally:
        # Release the workers even if sampling fails.
        pool.close()
    return None
示例5: run_pool
def run_pool(pC, pW, walk, step):  # pCenter and pWidths
    """Run a visibility-only emcee fit on an MPI pool and save the results.

    Parameters
    ----------
    pC : sequence of float
        Center of the initial walker ball, one entry per parameter
        (originally r in, delta r, i, PA, ...).
    pW : sequence of float
        Widths of the initial walker ball; needs at least ``len(pC)``
        entries.
    walk : int
        Number of walkers.
    step : int
        Number of MCMC steps.

    Writes ``<W>x<S>.chain.fits``, ``<W>x<S>.chi.fits`` and
    ``<W>x<S>.runInfo.txt`` into ``MCMCRUNS/vis_only/<W>x<S>/``. Worker MPI
    ranks block in ``pool.wait()`` and then exit the process.
    """
    steps = step
    nwalkers = walk
    ndim = len(pC)
    # Use every parameter of pC. Taking only the first five mismatched ndim
    # whenever len(pC) != 5.
    p0 = list(pC)
    widths = list(pW[:ndim])
    p = emcee.utils.sample_ball(p0, widths, size=nwalkers)

    pool = MPIPool()
    if not pool.is_master():
        pool.wait()
        sys.exit(0)

    try:
        sampler = emcee.EnsembleSampler(nwalkers, ndim, lnlike_visonly,
                                        live_dangerously=True, pool=pool)
        print('Beginning the MCMC run.')
        # time.clock() was removed in Python 3.8, and on Unix it measured
        # CPU time rather than the wall-clock duration of the run.
        start = time.time()
        sampler.run_mcmc(p, steps)
        stop = time.time()
    finally:
        # Release the workers even if sampling fails.
        pool.close()

    minutes = (stop - start) / 60.
    print('MCMC finished successfully.\n')
    print('This was a visibility-only run with {} walkers and {} steps'.format(nwalkers, steps))
    mean_acor = np.mean(sampler.acor)
    mean_acceptance = np.mean(sampler.acceptance_fraction)
    print("Mean acor time: " + str(mean_acor))
    print("\nMean acceptance fraction: " + str(mean_acceptance))
    print('Run took %r minutes' % minutes)

    chain = sampler.chain
    chi = (sampler.lnprobability) / (-0.5)

    whatbywhat = str(nwalkers) + 'x' + str(steps)
    out_dir = 'MCMCRUNS/vis_only/' + whatbywhat
    # Create the output directory (and any missing parents) without a shell.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    base = out_dir + '/' + whatbywhat
    chainFile = base + '.chain.fits'
    chiFile = base + '.chi.fits'
    infoFile = base + '.runInfo.txt'

    fits.writeto(chainFile, chain)
    fits.writeto(chiFile, chi)
    with open(infoFile, 'w') as f:
        f.write('run took %r minutes\n' % minutes)
        f.write('walkers: %r\n' % nwalkers)
        f.write('steps: %r\n' % steps)
        f.write('initial model: %r\n' % p0)
        f.write('widths: %r\n' % widths)
        f.write("mean acor time: " + str(mean_acor))
        f.write("\nmean acceptance fraction: " + str(mean_acceptance))
    print('Data written to: \n' + chainFile + '\n' + chiFile + '\n' + infoFile)
示例6: lnPost
def lnPost(theta):
    '''Log-posterior for a box (uniform) prior.

    Returns ``-np.inf`` unless each of the first five parameters lies
    strictly between the matching entries of ``prior_min`` and
    ``prior_max``. Otherwise returns ``0.0 + lnLike(theta)``.
    '''
    inside_prior = all(prior_min[k] < theta[k] < prior_max[k] for k in range(5))
    if not inside_prior:
        return -np.inf
    lnPrior = 0.0
    return lnPrior + lnLike(theta)


# Initialize the walkers in a small Gaussian ball (scale 1e-3) around a
# fixed starting guess.
pos = [np.array([11., np.log(.4), 11.5, 1.0, 13.5]) + 1e-3*np.random.randn(Ndim)
       for _ in range(Nwalkers)]

# Initialize the MPI pool; worker ranks block in wait() and then exit.
pool = MPIPool(loadbalance=True)
if not pool.is_master():
    pool.wait()
    sys.exit(0)

try:
    # NOTE(review): the sampler uses ``lnprob``, not the ``lnPost`` defined
    # above; confirm which log-posterior is intended.
    sampler = emcee.EnsembleSampler(Nwalkers, Ndim, lnprob, pool=pool)
    # Burn-in and production run as one continuous chain.
    sampler.run_mcmc(pos, Nchains_burn + Nchains_pro)
finally:
    # Release the workers even if sampling fails, so the MPI job can't hang.
    pool.close()

# Production samples: drop each walker's first Nchains_burn steps and
# flatten walkers x steps into rows of Ndim parameters.
samples = sampler.chain[:, Nchains_burn:, :].reshape((-1, Ndim))
np.savetxt("mcmc_sample.dat", samples)
示例7: main
def main():
    '''
    A parallel run.

    Creates a load-balanced MPI pool (worker ranks block in ``pool.wait()``
    and then exit), runs the HBSGC pipeline on the master rank and saves
    the results. The pool is closed even if a step raises, so worker ranks
    are never left waiting forever.
    '''
    pool = MPIPool(loadbalance=True)
    if not pool.is_master():
        pool.wait()
        sys.exit(0)

    try:
        clf = hbsgc.HBSGC(pool=pool)

        # save start time
        # NOTE(review): time.clock() no longer exists on Python >= 3.8; any
        # replacement must presumably match the clock hbsgc compares
        # last_clock against.
        clf.last_clock = time.clock()

        clf.filter_calcs()
        clf.data_calcs()
        clf.star_model_calcs()
        # if clf.calc_model_mags:
        #     clf.star_model_mags()
        clf.gal_model_calcs()
        # if clf.calc_model_mags:
        #     clf.gal_model_mags()
        clf.fit_calcs()
        clf.count_tot = 0
        clf.sample()
        clf.save_proba()
        if clf.min_chi2_write:
            clf.save_min_chi2()
    finally:
        pool.close()
示例8: ens_mpi_sample
def ens_mpi_sample(gf, nwalkers, burn_steps, sample_steps, pos=None,
                   random_state=None):
    """Sample the posterior of ``gf`` with an emcee EnsembleSampler over MPI.

    Parameters
    ----------
    gf :
        Fit object. ``gf.builder.global_params``,
        ``gf.builder.local_params``, ``gf.data`` and ``gf.priors`` are read,
        and ``gf`` is passed to ``posterior`` as its extra argument.
    nwalkers : int
        Number of walkers.
    burn_steps : int
        Burn-in steps; not stored in the chain.
    sample_steps : int
        Main sampling steps, stored in the returned sampler.
    pos : array of shape (nwalkers, ndim), optional
        Starting positions. If None, they are drawn from ``gf.priors``.
    random_state : optional
        If given, assigned to ``sampler.random_state`` before sampling.

    Returns
    -------
    emcee.EnsembleSampler
        Master rank only; worker ranks block in ``pool.wait()`` and then
        exit the process.
    """
    pool = MPIPool(loadbalance=True)
    if not pool.is_master():
        pool.wait()
        sys.exit(0)

    # Number of parameters to estimate: the global ones plus one set of
    # local ones per data set.
    ndim = (len(gf.builder.global_params) +
            (len(gf.data) * len(gf.builder.local_params)))

    if pos is None:
        # Draw starting positions from the priors. The priors are already in
        # log10 scale, so the draws need no transformation.
        p0 = np.zeros((nwalkers, ndim))
        for walk_ix in range(nwalkers):
            for p_ix in range(ndim):
                p0[walk_ix, p_ix] = gf.priors[p_ix].random()
    else:
        p0 = pos

    try:
        sampler = emcee.EnsembleSampler(nwalkers, ndim, posterior,
                                        args=[gf], pool=pool)
        if random_state is not None:
            sampler.random_state = random_state
        print("Burn in sampling...")
        pos, _, _ = sampler.run_mcmc(p0, burn_steps, storechain=False)
        sampler.reset()
        print("Main sampling...")
        sampler.run_mcmc(pos, sample_steps)
    finally:
        # Close the pool even if sampling fails, releasing the workers.
        pool.close()
    print("Done sampling.")
    return sampler
示例9: MPIPool
cmd_args = parser.parse_args()
if cmd_args.options_file is None:
    parser.print_help()
    sys.exit(0)

# Set verbosity level
if cmd_args.verbose:
    logging.basicConfig(level=logging.DEBUG)
else:
    logging.basicConfig(level=logging.INFO)

# Initialize MPIPool. Fall back to serial execution (pool = None) only on
# emcee's expected "no MPI" failures: ImportError (mpi4py unavailable) and
# ValueError (fewer than two MPI processes). A bare except would also
# swallow KeyboardInterrupt, SystemExit and genuine bugs.
try:
    pool = MPIPool()
except (ImportError, ValueError):
    pool = None

if (pool is not None) and not pool.is_master():
    # Worker rank: serve tasks until the master closes the pool, then
    # synchronize, shut MPI down and exit.
    pool.wait()
    pool.comm.Barrier()
    MPI.Finalize()
    sys.exit(0)

# Set progressbar attributes
widgets = ["Progress: ", progressbar.Percentage(), ' ', progressbar.Bar(marker="+")]

# Parse INI options file
options = ConfigParser.ConfigParser()
示例10: KSmap_single
# NOTE(review): the `def` line of the function this docstring and body
# belong to is not shown in this excerpt. Its single argument is unpacked
# below into (i, R, cosmo, sigmaG).
'''grab KSmap, create 1, 1.8, 3.5, 5.7 peaks
each time create 1000 realizations for specifice i, cosmo, sigmaG
'''
# Unpack the task tuple; the same four values are passed to KSmap_single.
i, R, cosmo, sigmaG = iiRcosmoSigma
kmap = KSmap_single(i, R, cosmo, sigmaG)
# Position of sigmaG in sigmaG_arr. int() requires exactly one match and
# raises TypeError otherwise. NOTE(review): int() on a 1-element array is
# deprecated in recent NumPy; consider indexing [0][0].
idx = int(where(sigmaG_arr==sigmaG)[0])
# Masks matching this smoothing scale: the "all" and the "bad" variant.
mask_all = mask_all_arr[idx]
mask_bad = mask_bad_arr[idx]
# Peak histograms of the same map under each mask (bins, kmin and kmax come
# from module scope).
pspk_all = WLanalysis.peaks_mask_hist(kmap, mask_all, bins, kmin = kmin, kmax = kmax)
pspk_bad = WLanalysis.peaks_mask_hist(kmap, mask_bad, bins, kmin = kmin, kmax = kmax)
return pspk_all, pspk_bad
###############################################################
######## operations ###########################################
###############################################################
# NOTE(review): no is_master()/wait() guard and no pool.close() is visible
# here. emcee's MPIPool.map makes worker ranks wait inside map(), but confirm
# the pool is closed later, or the workers never exit.
pool = MPIPool()
####### (0)test corrupted SIM file ###########
# Disabled step: flag simulation files that fail TestFitsComplete and save
# the per-(cosmo, R) flags to KS_dir/badfiles.npy.
#def test_corrupte (iRcosmo):
#cosmo, R = iRcosmo
#if WLanalysis.TestFitsComplete(SIMfn(i, cosmo, R))==False:
#print SIMfn(i, cosmo, R)
#return 1
#else:
#return 0
#Rcosmo = [[ cosmo, R] for R in R_arr for cosmo in cosmo_arr]
#badfiles = array(pool.map(test_corrupte, Rcosmo))
#save(KS_dir+'badfiles.npy',badfiles)
######################################################
### (1)create KS map, uncomment next 4 lines #########
示例11: sqrt
# NOTE(review): tail of a function whose header is not shown in this
# excerpt (presumably kappa_individual_gal, which temp() maps over below).
# When the contribution kappa_temp is positive, log the pair.
if kappa_temp>0:
# Separation between the foreground and background positions.
theta = sqrt((x_fore-x_back)**2+(y_fore-y_back)**2)
# Logged columns: i, jj, log10(jMvir), z_fore, z_back, separation in
# arcmin, kappa_temp.
print '%i\t%s\t%.2f\t%.3f\t%.3f\t%.4f\t%.6f'%(i, jj,log10(jMvir), z_fore, z_back, rad2arcmin(theta), kappa_temp)
return ikappa
# Disabled spot-check on 5 randomly chosen background galaxies.
#a=map(kappa_individual_gal, randint(0,len(idx_back)-1,5))
# Number of background galaxies per task / temp file. This is an int so
# that arange() below yields integer indices. The previous float step (2e3)
# produced float indices, which NumPy rejects if kappa_individual_gal uses
# them to index arrays (presumably it does).
step = 2000


def temp(ix):
    '''Compute kappa for background galaxies ix .. ix+step-1 (clipped to
    len(idx_back)) and cache the list in a per-chunk .npy file.

    Does nothing if that chunk's file already exists, so an interrupted run
    can be resumed.
    '''
    print(ix)
    temp_fn = obsPK_dir+'temp/kappa_proj%i_%07d.npy'%(Wx, ix)
    if not os.path.isfile(temp_fn):
        # list(...) materializes the results; on Python 3, map() is lazy and
        # np.save would store the map object itself.
        kappa_all = list(map(kappa_individual_gal, arange(ix, amin([len(idx_back), ix+step]))))
        np.save(temp_fn, kappa_all)
# NOTE(review): no is_master() guard and no pool.close() here. With emcee's
# MPIPool, worker ranks block inside map() until the pool is closed, so as
# written they never return and the job may hang. If they did return, every
# rank would also run the concatenate/save below. Confirm before reuse.
pool = MPIPool()
# Chunk start indices: one temp() task (and one temp file) per chunk.
ix_arr = arange(0, len(idx_back), step)
pool.map(temp, ix_arr)
# Stitch the per-chunk results back together in chunk order and save them.
all_kappa_proj = concatenate([np.load(obsPK_dir+'temp/kappa_proj%i_%07d.npy'%(Wx, ix)) for ix in ix_arr])
np.save(obsPK_dir+'kappa_predict_W%i.npy'%(Wx), all_kappa_proj)
#########################################################
####################### plotting correlation ############
#########################################################
# 0/1 switches, presumably toggling the stages that follow this excerpt.
make_predict_maps = 0
plot_predict_maps = 0
peak_proj_vs_lensing = 1
cross_correlate = 0
#kmap_predict_Gen = lambda Wx, sigmaG: np.load(obsPK_dir+'maps/r20arcmin_varyingcNFW_VO06/kmap_W%i_predict_sigmaG%02d.npy'%(Wx, sigmaG*10))
示例12: main
def main():
    """Command-line entry point for the likelihood analysis.

    Parses the options, optionally starts an MPI pool, and loads the
    covariance matrix, the observations and the pickled emulator. Under MPI,
    worker ranks block in ``pool.wait()`` and then exit; only the master
    continues. Without MPI (mpi4py missing, or a single process) ``pool`` is
    None and everything runs serially.
    """
    #################################################
    ############Option parsing#######################
    #################################################

    # Parse command line options
    parser = argparse.ArgumentParser()
    parser.add_argument("-f","--file",dest="options_file",action="store",type=str,help="analysis options file")
    parser.add_argument("-v","--verbose",dest="verbose",action="store_true",default=False,help="turn on verbosity")
    parser.add_argument("-vv","--verbose_plus",dest="verbose_plus",action="store_true",default=False,help="turn on additional verbosity")
    parser.add_argument("-m","--mask_scale",dest="mask_scale",action="store_true",default=False,help="scale peaks and power spectrum to unmasked area")
    parser.add_argument("-c","--cut_convergence",dest="cut_convergence",action="store",default=None,help="select convergence values in (min,max) to compute the likelihood. Safe for single descriptor only!!")
    parser.add_argument("-g","--group_subfields",dest="group_subfields",action="store_true",default=False,help="group feature realizations by taking the mean over subfields, this makes a big difference in the covariance matrix")
    parser.add_argument("-s","--save_features",dest="save_features",action="store_true",default=False,help="save features profiles")
    parser.add_argument("-ss","--save",dest="save",action="store_true",default=False,help="save the best fits and corresponding chi2")
    parser.add_argument("-p","--prefix",dest="prefix",action="store",default="",help="prefix of the emulator to pickle")
    parser.add_argument("-l","--likelihood",dest="likelihood",action="store_true",default=False,help="save the likelihood cubes for the mocks")
    parser.add_argument("-o","--observation",dest="observation",action="store_true",default=False,help="append the actual observation results to the mock results for direct comparison")
    parser.add_argument("-d","--differentiate",dest="differentiate",action="store_true",default=False,help="differentiate the first minkowski functional to get the PDF")
    cmd_args = parser.parse_args()

    if cmd_args.options_file is None:
        parser.print_help()
        sys.exit(0)

    # Set verbosity level
    if cmd_args.verbose_plus:
        logging.basicConfig(level=DEBUG_PLUS)
    elif cmd_args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    # Initialize MPI Pool. Fall back to serial execution only on emcee's
    # expected "no MPI" failures: ImportError (mpi4py unavailable) and
    # ValueError (fewer than two MPI processes). A bare except would also
    # swallow KeyboardInterrupt, SystemExit and genuine bugs.
    try:
        pool = MPIPool()
    except (ImportError, ValueError):
        pool = None

    if (pool is not None) and (not pool.is_master()):
        pool.wait()
        sys.exit(0)

    if pool is not None:
        logging.info("Started MPI Pool.")

    #################################################################################
    ######Info gathering: covariance matrix, observation and emulator################
    #################################################################################

    # start
    start = time.time()
    last_timestamp = start

    # Instantiate a FeatureLoader object that will take care of the memory loading
    feature_loader = FeatureLoader(cmd_args)

    ###########################################################################
    # Use this model for the covariance matrix (from the new set of 50 N body simulations)
    covariance_model = CFHTcov.getModels(root_path=feature_loader.options.get("simulations","root_path"))
    logging.info("Measuring covariance matrix from model {0}".format(covariance_model))

    # Load in the covariance matrix
    fiducial_feature_ensemble = feature_loader.load_features(covariance_model)
    fiducial_features = fiducial_feature_ensemble.mean()
    features_covariance = fiducial_feature_ensemble.covariance()

    # timestamp
    now = time.time()
    logging.info("covariance loaded in {0:.1f}s".format(now-last_timestamp))
    last_timestamp = now

    ###########################################################################
    # Treat the 50N-body simulation set as data
    observation = CFHTcov.getModels(root_path=feature_loader.options.get("observations","root_path"))
    logging.info("Measuring the observations from {0}".format(observation))

    # And load the observations
    observed_feature = feature_loader.load_features(observation)

    # timestamp
    now = time.time()
    logging.info("observation loaded in {0:.1f}s".format(now-last_timestamp))
    last_timestamp = now

    ###########################################################################
    # Create a LikelihoodAnalysis instance by unpickling one of the emulators
    emulators_dir = os.path.join(feature_loader.options.get("analysis","save_path"),"emulators")
    emulator_file = os.path.join(emulators_dir,"emulator{0}_{1}.p".format(cmd_args.prefix,output_string(feature_loader.feature_string)))
    logging.info("Unpickling emulator from {0}...".format(emulator_file))
    analysis = LikelihoodAnalysis.load(emulator_file)

    # timestamp
    now = time.time()
    logging.info("emulator unpickled in {0:.1f}s".format(now-last_timestamp))

    # ......... (the rest of this function is omitted in this excerpt) .........
示例13: savetxt
# NOTE(review): tail of a function whose header is not shown in this excerpt
# (presumably pdfs(i), which is mapped over range(1,14) below).
# Normalized histogram of rndz over `edges`: each bin holds its fraction of
# the samples.
rndz=histogram(rndz,bins=edges)[0]/float(len(rndz))
# Keep the 70 columns 19..88 of m.
all_PDF=m[:,range(89-70,89)]
# Normalize each row to sum to 1 (in place).
all_PDF/=sum(all_PDF,axis=1)[:,newaxis]
# Mean PDF over all rows.
average_PDF=average(all_PDF,axis=0)
# Save the per-subfield outputs, suffixed with i.
savetxt(full_dir+'avgPDF_arr%s'%(i),average_PDF)
savetxt(full_dir+'pks_arr%s'%(i),pk)
savetxt(full_dir+'rnds_arr%s'%(i),rndz)
# for CFHT each bin has 3.5/70 probability, for SIM, each bin has 3.5/67 height
# Returns the full row-normalized array (not average_PDF), pk, rndz and the
# number of rows.
return all_PDF, pk, rndz, len(all_PDF)
import sys

# One row per subfield 1..13, filled from the pdfs() results below.
avgPDF_arr = zeros(shape=(13,70))
pks_arr = zeros(shape=(13,67))
rnds_arr = zeros(shape=(13,67))
lena_arr = zeros(13)

p = MPIPool()
# Worker ranks serve pdfs() tasks until the master closes the pool, then
# exit, so only the master collects and saves the results.
if not p.is_master():
    p.wait()
    sys.exit(0)
try:
    x = p.map(pdfs, range(1,14))
finally:
    # Release the workers; otherwise they stay blocked and the job hangs.
    p.close()

for i in range(1,14):
    aa, peaks, rnds, lena = x[i-1]
    # pdfs() returns the full row-normalized array (one 70-bin PDF per row).
    # Store its row average, which is the subfield's average PDF. The
    # previous target, `avgPDF_arr_arr`, was an undefined name.
    avgPDF_arr[i-1] = average(aa, axis=0)
    pks_arr[i-1] = peaks
    rnds_arr[i-1] = rnds
    lena_arr[i-1] = lena
    print('{0} {1}'.format(i, lena))

out_dir = '/direct/astro+astronfs01/workarea/jia/CFHT/full_subfields/'
savetxt(out_dir+'avgPDF_arr.ls', avgPDF_arr)
savetxt(out_dir+'pks_arr.ls', pks_arr)
savetxt(out_dir+'rnds_arr.ls', rnds_arr)
savetxt(out_dir+'lena_arr.ls', lena_arr)
print('done-done-done')
示例14: main
def main(argv):
# Jointly fit shared detector parameters and per-waveform parameters with an
# MCMC over an MPI pool (nwalkers/burnIn below; the sampler setup itself is
# in the part of this function omitted from the excerpt). Python 2 code.
##################
#These change a lot
numWaveforms = 16
# NOTE(review): numThreads is not used in the visible part of main.
numThreads = 12
# 6 parameters per waveform (r, phi, z, scale, t0, smooth) plus 8 shared
# ones: 4 detector parameters (indices -8..-5 below) and 4 for the transfer
# function.
ndim = 6*numWaveforms + 8
nwalkers = 2*ndim
# NOTE(review): this shadows the builtin iter() for the rest of main.
iter=50
burnIn = 40
wfPlotNumber = 10
######################
# plt.ion()
fitSamples = 200
#Prepare detector
# Zeros, poles and gain (1E7) of the electronics transfer function, built
# as a scipy.signal.lti in zero-pole-gain form.
zero_1 = -5.56351644e+07
pole_1 = -1.38796386e+04
pole_real = -2.02559385e+07
pole_imag = 9885315.37450211
# NOTE(review): if numpy is star-imported, this local shadows numpy's zeros().
zeros = [zero_1,0 ]
poles = [ pole_real+pole_imag*1j, pole_real-pole_imag*1j, pole_1]
system = signal.lti(zeros, poles, 1E7 )
# Initial guesses for the shared detector parameters.
tempGuess = 77.89
gradGuess = 0.0483
pcRadGuess = 2.591182
pcLenGuess = 1.613357
#Create a detector model
# NOTE(review): the config file name uses fixed values (0.05, 2.5, 1.65),
# not the *Guess values above; the guesses are applied via SetFields below.
detName = "conf/P42574A_grad%0.2f_pcrad%0.2f_pclen%0.2f.conf" % (0.05,2.5, 1.65)
det = Detector(detName, temperature=tempGuess, timeStep=1., numSteps=fitSamples*10, tfSystem=system)
det.LoadFields("P42574A_fields_v3.npz")
det.SetFields(pcRadGuess, pcLenGuess, gradGuess)
# Positions of the shared detector parameters, counted from the end of the
# parameter vector.
tempIdx = -8
gradIdx = -7
pcRadIdx = -6
pcLenIdx = -5
#and the remaining 4 are for the transfer function
fig_size = (20,10)
#Create a decent start guess by fitting waveform-by-waveform
wfFileName = "P42574A_512waveforms_%drisetimeculled.npz" % numWaveforms
if os.path.isfile(wfFileName):
data = np.load(wfFileName)
results = data['results']
wfs = data['wfs']
# NOTE(review): ndim/nwalkers above were computed from the initial
# numWaveforms; they go stale if the file holds a different count.
numWaveforms = wfs.size
else:
# NOTE(review): despite the message, nothing is loaded; the script exits
# with status 0 (success). A non-zero status would signal the failure.
print "No saved waveforms available. Loading from Data"
exit(0)
#prep holders for each wf-specific param
r_arr = np.empty(numWaveforms)
phi_arr = np.empty(numWaveforms)
z_arr = np.empty(numWaveforms)
scale_arr = np.empty(numWaveforms)
t0_arr = np.empty(numWaveforms)
smooth_arr = np.ones(numWaveforms)*7.
simWfArr = np.empty((1,numWaveforms, fitSamples))
#Prepare the initial value arrays
# Window each waveform and copy its saved per-waveform fit (results[idx]['x'])
# into the holders.
for (idx, wf) in enumerate(wfs):
wf.WindowWaveformTimepoint(fallPercentage=.99)
r_arr[idx], phi_arr[idx], z_arr[idx], scale_arr[idx], t0_arr[idx], smooth_arr[idx] = results[idx]['x']
t0_arr[idx] += 10 #because i had a different windowing offset back in the day
#Plot the waveforms to take a look at the initial guesses
# Disabled interactive diagnostic (change False to True to enable).
if False:
fig = plt.figure()
for (idx,wf) in enumerate(wfs):
print "WF number %d:" % idx
print " >>r: %f\n >>phi %f\n >>z %f\n >>e %f\n >>t0 %f\n >>smooth %f" % (r_arr[idx], phi_arr[idx], z_arr[idx], scale_arr[idx], t0_arr[idx], smooth_arr[idx])
ml_wf = det.GetSimWaveform(r_arr[idx], phi_arr[idx], z_arr[idx], scale_arr[idx]*100, t0_arr[idx], fitSamples, smoothing = smooth_arr[idx])
plt.plot(ml_wf, color="b")
plt.plot(wf.windowedWf, color="r")
value = raw_input(' --> Press q to quit, any other key to continue\n')
if value == 'q': exit(0)
#Initialize this thread's globals
initializeDetectorAndWaveforms(det, wfs)
#Initialize the MPI pool (not threads): worker ranks block in wait(), then exit
pool = MPIPool()
if not pool.is_master():
pool.wait()
sys.exit(0)
#......... (the rest of this function is omitted in this excerpt) .........
示例15: fit
def fit(self, kwargs):
    """Run the sampler over the model and return the flat chain of results.

    If the process was started under MPI, an emcee ``MPIPool`` is used.
    Non-master ranks log "Slave waiting", block in ``pool.wait()`` and then
    exit the process; only the master continues. If the pool cannot be
    created (ImportError or ValueError), sampling uses whatever
    ``self.pool`` already holds (presumably None, i.e. serial).

    Parameters
    ----------
    kwargs : dict
        Containing the following information at a minimum:

        - log_posterior : function
            A function which takes a list of parameters and returns
            the log posterior
        - start : function|list|ndarray
            Either a starting position, or a function that can be called
            to generate a starting position
        - save_dims : int, optional
            Only return values for the first ``save_dims`` parameters.
            Useful to remove numerous marginalisation parameters if running
            low on memory or hard drive space.
        - uid : str, optional
            A unique identifier used to differentiate different fits
            if two fits both serialise their chains and use the
            same temporary directory

    Returns
    -------
    dict
        A dictionary with key "chain" containing the final
        flattened chain of dimensions
        ``(num_dimensions, num_walkers * (num_steps - num_burn))``
    """
    log_posterior = kwargs.get("log_posterior")
    start = kwargs.get("start")
    save_dims = kwargs.get("save_dims")
    uid = kwargs.get("uid")
    assert log_posterior is not None
    assert start is not None
    import sys

    import numpy as np
    from emcee.utils import MPIPool
    import emcee

    try:  # pragma: no cover
        self.pool = MPIPool()
        if not self.pool.is_master():
            self.logger.info("Slave waiting")
            self.master = False
            self.pool.wait()
            sys.exit(0)
        else:
            self.logger.info("MPIPool successful initialised and master found. "
                             "Running with %d cores." % self.pool.size)
    except ImportError:
        self.logger.info("mpi4py is not installed or not configured properly. "
                         "Ignore if running through python, not mpirun")
    except ValueError as e:  # pragma: no cover
        self.logger.info("Unable to start MPI pool, expected normal python execution")
        self.logger.info(str(e))

    # ``start`` may be a callable, a list or an ndarray. np.asarray makes
    # the documented list form work too, since a plain list has no ``.size``.
    initial = start() if callable(start) else start
    num_dim = np.asarray(initial).size
    if self.num_walkers is None:
        # Default: 4 walkers per dimension, but never fewer than 20.
        self.num_walkers = num_dim * 4
        self.num_walkers = max(self.num_walkers, 20)

    self.logger.debug("Fitting framework with %d dimensions" % num_dim)
    self.logger.info("Using Ensemble Sampler")
    sampler = emcee.EnsembleSampler(self.num_walkers, num_dim,
                                    log_posterior,
                                    pool=self.pool, live_dangerously=True)

    emcee_wrapper = EmceeWrapper(sampler)
    try:
        flat_chain = emcee_wrapper.run_chain(self.num_steps, self.num_burn,
                                             self.num_walkers, num_dim,
                                             start=start,
                                             save_dim=save_dims,
                                             temp_dir=self.temp_dir,
                                             uid=uid,
                                             save_interval=self.save_interval)
        self.logger.debug("Fit finished")
    finally:
        # Close the pool even if sampling fails, so MPI workers are released.
        if self.pool is not None:  # pragma: no cover
            self.pool.close()
            self.logger.debug("Pool closed")
    return {"chain": flat_chain}