Source code for pint.scripts.event_optimize_multiple

#!/usr/bin/env python -W ignore::FutureWarning -W ignore::UserWarning -W ignore::DeprecationWarning
import argparse
import contextlib
import sys
import pickle

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from astropy.coordinates import SkyCoord
import pint.logging
from loguru import logger as log

log.remove()
log.add(
    sys.stderr,
    level="WARNING",
    colorize=True,
    format=pint.logging.format,
    filter=pint.logging.LogFilter(),
)
import pint.fermi_toas as fermi
import pint.models
import pint.toa as toa
from pint.templates import lctemplate, lcfitters
from pint.residuals import Residuals
from pint.mcmc_fitter import CompositeMCMCFitter
from pint.observatory.satellite_obs import get_satellite_observatory
from pint.sampler import EmceeSampler
from pint.scripts.event_optimize import read_gaussfitfile


__all__ = ["main"]
# np.seterr(all='raise')

# initialization values
# Should probably figure a way to make these not global variables
maxpost = -9e99
numcalls = 0


def get_toas(evtfile, flags, tcoords=None, minweight=0, minMJD=0, maxMJD=100000):
    if evtfile[-3:] == "tim":
        usepickle = flags["usepickle"] if "usepickle" in flags else False
        ts = toa.get_TOAs(evtfile, usepickle=usepickle)
        # Prune out-of-range MJDs
        mask = np.logical_or(
            ts.get_mjds() < minMJD * u.day, ts.get_mjds() > maxMJD * u.day
        )
        ts.table.remove_rows(mask)
        ts.table = ts.table.group_by("obs")
    else:
        if "usepickle" in flags and flags["usepickle"]:
            with contextlib.suppress(Exception):
                picklefile = toa._check_pickle(evtfile) or evtfile
                return toa.TOAs(picklefile)
        weightcol = flags["weightcol"] if "weightcol" in flags else None
        target = tcoords if weightcol == "CALC" else None
        tl = fermi.load_Fermi_TOAs(
            evtfile, weightcolumn=weightcol, targetcoord=target, minweight=minweight
        )
        tl = filter(lambda t: (t.mjd.value > minMJD) and (t.mjd.value < maxMJD), tl)
        ts = toa.TOAs(toalist=tl)
        ts.filename = evtfile
        ts.compute_TDBs()
        ts.compute_posvels(ephem="DE421", planets=False)
        ts.pickle()
    log.info("There are %d events we will use" % len(ts.table))
    return ts

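
# A minimal usage sketch for get_toas (the event file name, coordinates, and
# flag values below are hypothetical, not part of this script):
#
#   tc = SkyCoord("05h34m31.97s", "+22d00m52.07s", frame="icrs")
#   ts = get_toas("J0534+2200_fermi.fits", {"weightcol": "CALC"},
#                 tcoords=tc, minweight=0.05, minMJD=54680, maxMJD=57250)
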
def load_eventfiles(infile, tcoords=None, minweight=0, minMJD=0, maxMJD=100000):
    """Load events from multiple sources.

    The format of each line of infile is:

        <eventfile> <log_likelihood function> <template> [flags]

    Allowed flags are:

        setweights
            A multiplicative weight to apply to the probability function for this eventfile
        usepickle
            Load from a pickle file
        weightcol
            The weight column in the fits file
    """
    with open(infile, "r") as f:
        lines = f.read().split("\n")
    eventinfo = {
        "toas": [],
        "lnlikes": [],
        "templates": [],
        "weightcol": [],
        "setweights": [],
    }
    for line in lines:
        log.info(f"{line}")
        if len(line) == 0:
            continue
        try:
            words = line.split()
            flags = {}
            if len(words) > 3:
                kvs = words[3:]
                for i in range(0, len(kvs), 2):
                    k, v = kvs[i].lstrip("-"), kvs[i + 1]
                    flags[k] = v
            ts = get_toas(
                words[0],
                flags,
                tcoords=tcoords,
                minweight=minweight,
                minMJD=minMJD,
                maxMJD=maxMJD,
            )
            eventinfo["toas"].append(ts)
            log.info("%s has %d events" % (words[0], len(ts.table)))
            eventinfo["lnlikes"].append(words[1])
            eventinfo["templates"].append(words[2])
            if "setweights" in flags:
                eventinfo["setweights"].append(float(flags["setweights"]))
            else:
                eventinfo["setweights"].append(1.0)
            if "weightcol" in flags:
                eventinfo["weightcol"].append(flags["weightcol"])
            else:
                eventinfo["weightcol"].append(None)
        except Exception as e:
            log.error(f"{str(e)}")
            log.error(f"Could not load {line}")
    return eventinfo

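
# A sketch of what the input file might contain (the file and template names
# are hypothetical; "prob" and "resid" select the likelihood functions
# defined below, and a template of "none" skips template loading):
#
#   J0534+2200_fermi.fits prob J0534+2200_template.pickle --weightcol CALC
#   J0534+2200_radio.tim resid none --usepickle True --setweights 0.5
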
def lnlikelihood_prob(ftr, theta, index):
    phases = ftr.get_event_phases(index)
    phss = (phases.astype(np.float64) + theta[-1]) % 1
    probs = ftr.get_template_vals(phss, index)
    if ftr.weights[index] is None:
        return np.log(probs).sum()
    else:
        return np.log(ftr.weights[index] * probs + 1.0 - ftr.weights[index]).sum()

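
# For intuition: with photon weights w_i and a normalized template f, the
# weighted log-likelihood above is sum_i log(w_i * f(phi_i) + (1 - w_i)).
# A toy check (all values hypothetical):
#
#   probs = np.array([2.0, 0.5])       # template values at two photon phases
#   w = np.array([1.0, 0.0])           # fully weighted photon vs. pure background
#   np.log(w * probs + 1.0 - w).sum()  # = log(2.0) + log(1.0) ~= 0.693
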
def lnlikelihood_resid(ftr, theta, index):
    return -Residuals(toas=ftr.toas_list[index], model=ftr.model).chi2.value

def main(argv=None):
    parser = argparse.ArgumentParser(
        description="PINT tool for MCMC optimization of timing models using event data from multiple sources.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("eventfiles", help="Specify a file listing all event files")
    parser.add_argument("parfile", help="par file to read model from")
    parser.add_argument("--ft2", help="Path to FT2 file.", default=None)
    parser.add_argument(
        "--nwalkers", help="Number of MCMC walkers", type=int, default=200
    )
    parser.add_argument(
        "--burnin",
        help="Number of MCMC steps for burn in",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--nsteps",
        help="Number of MCMC steps to compute",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--minMJD", help="Earliest MJD to use", type=float, default=54680.0
    )
    parser.add_argument(
        "--maxMJD", help="Latest MJD to use", type=float, default=57250.0
    )
    parser.add_argument(
        "--phs",
        help="Starting phase offset [0-1] (default is to measure)",
        type=float,
    )
    parser.add_argument(
        "--phserr", help="Error on starting phase", type=float, default=0.03
    )
    parser.add_argument(
        "--minWeight",
        help="Minimum weight to include",
        type=float,
        default=0.05,
    )
    parser.add_argument(
        "--wgtexp",
        help="Raise computed weights to this power (or 0.0 to disable any rescaling of weights)",
        type=float,
        default=0.0,
    )
    parser.add_argument(
        "--testWeights",
        help="Make plots to evaluate weight cuts?",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--initerrfact",
        help="Multiply par file errors by this factor when initializing walker starting values",
        type=float,
        default=0.1,
    )
    parser.add_argument(
        "--priorerrfact",
        help="Multiply par file errors by this factor when setting gaussian prior widths",
        type=float,
        default=10.0,
    )
    parser.add_argument(
        "--samples",
        help="Pickle file containing samples from a previous run",
        default=None,
    )
    parser.add_argument(
        "--log-level",
        type=str,
        choices=("TRACE", "DEBUG", "INFO", "WARNING", "ERROR"),
        default=pint.logging.script_level,
        help="Logging level",
        dest="loglevel",
    )

    global nwalkers, nsteps, ftr

    args = parser.parse_args(argv)
    log.remove()
    log.add(
        sys.stderr,
        level=args.loglevel,
        colorize=True,
        format=pint.logging.format,
        filter=pint.logging.LogFilter(),
    )

    parfile = args.parfile

    if args.ft2 is not None:
        # Instantiate the Fermi observatory once so it gets added to the observatory registry
        get_satellite_observatory("Fermi", args.ft2)

    nwalkers = args.nwalkers
    burnin = args.burnin
    nsteps = args.nsteps
    if burnin >= nsteps:
        log.error("burnin must be < nsteps")
        sys.exit(1)
    nbins = 256  # For likelihood calculation based on gaussians file
    outprof_nbins = 256  # in the text file, for pygaussfit.py, for instance
    minMJD = args.minMJD
    maxMJD = args.maxMJD  # Usually set by coverage of IERS file
    minWeight = args.minWeight
    wgtexp = args.wgtexp

    # Read in initial model
    modelin = pint.models.get_model(parfile)

    # Set the target coords for automatic weighting if necessary
    if "ELONG" in modelin.params:
        tc = SkyCoord(
            modelin.ELONG.quantity,
            modelin.ELAT.quantity,
            frame="barycentrictrueecliptic",
        )
    else:
        tc = SkyCoord(modelin.RAJ.quantity, modelin.DECJ.quantity, frame="icrs")

    eventinfo = load_eventfiles(
        args.eventfiles, tcoords=tc, minweight=minWeight, minMJD=minMJD, maxMJD=maxMJD
    )

    nsets = len(eventinfo["toas"])
    log.info(
        "Total number of events:\t%d"
        % np.array([len(t.table) for t in eventinfo["toas"]]).sum()
    )
    log.info("Total number of datasets:\t%d" % nsets)

    funcs = {"prob": lnlikelihood_prob, "resid": lnlikelihood_resid}
    lnlike_funcs = [None] * nsets
    wlist = [None] * nsets
    gtemplates = [None] * nsets

    # Loop over all TOA sets
    for i in range(nsets):
        # Determine the lnlikelihood function for this set
        try:
            lnlike_funcs[i] = funcs[eventinfo["lnlikes"][i]]
        except KeyError:
            raise ValueError(f'{eventinfo["lnlikes"][i]} is not a recognized function')

        # Load in weights
        ts = eventinfo["toas"][i]
        if eventinfo["weightcol"][i] is not None:
            if eventinfo["weightcol"][i] == "CALC":
                weights = np.asarray([x["weight"] for x in ts.table["flags"]])
                log.info(
                    "Original weights have min / max weights %.3f / %.3f"
                    % (weights.min(), weights.max())
                )
                # Rescale the weights, if requested (by having wgtexp != 0.0)
                if wgtexp != 0.0:
                    weights **= wgtexp
                    wmx, wmn = weights.max(), weights.min()
                    # make the highest weight = 1, but keep min weight the same
                    weights = wmn + ((weights - wmn) * (1.0 - wmn) / (wmx - wmn))
                for ii, x in enumerate(ts.table["flags"]):
                    x["weight"] = weights[ii]
            weights = np.asarray([x["weight"] for x in ts.table["flags"]])
            log.info(
                "There are %d events, with min / max weights %.3f / %.3f"
                % (len(weights), weights.min(), weights.max())
            )
        else:
            weights = None
            log.info("There are %d events, no weights are being used." % ts.ntoas)
        wlist[i] = weights

        # Load in templates
        tname = eventinfo["templates"][i]
        if tname == "none":
            continue
        if tname[-6:] == "pickle" or tname == "analytic":
            # Analytic template
            try:
                gtemplate = pickle.load(open(tname, "rb"))
            except Exception:
                phases = (modelin.phase(ts)[1].value).astype(np.float64) % 1
                gtemplate = lctemplate.get_gauss2()
                lcf = lcfitters.LCFitter(gtemplate, phases, weights=wlist[i])
                lcf.fit(unbinned=False)
                pickle.dump(
                    gtemplate,
                    open("%s_template%d.pickle" % (modelin.PSR.value, i), "wb"),
                    protocol=2,
                )
            phases = (modelin.phase(ts)[1].value).astype(np.float64) % 1
            lcf = lcfitters.LCFitter(
                gtemplate, phases, weights=wlist[i], binned_bins=200
            )
            lcf.fit_position(unbinned=False)
            lcf.fit(overall_position_first=True, estimate_errors=False, unbinned=False)
            for prim in lcf.template:
                prim.free[:] = False
            lcf.template.norms.free[:] = False
        else:
            # Binned template
            gtemplate = read_gaussfitfile(tname, nbins)
            gtemplate /= gtemplate.mean()
        gtemplates[i] = gtemplate

    # Set the priors on the parameters in the model, before
    # instantiating the emcee_fitter
    # Currently, this adds a gaussian prior on each parameter
    # with width equal to the par file uncertainty * priorerrfact,
    # and then puts in some special cases.
    # *** This should be replaced/supplemented with a way to specify
    # more general priors on parameters that need certain bounds
    phs = 0.0 if args.phs is None else args.phs

    sampler = EmceeSampler(nwalkers)
    ftr = CompositeMCMCFitter(
        eventinfo["toas"],
        modelin,
        sampler,
        lnlike_funcs,
        templates=gtemplates,
        weights=wlist,
        phs=phs,
        phserr=args.phserr,
        minMJD=minMJD,
        maxMJD=maxMJD,
    )

    fitkeys, fitvals, fiterrs = ftr.get_fit_keyvals()

    # Use this if you want to see the effect of setting minWeight
    if args.testWeights:
        log.info("Checking H-test vs weights")
        ftr.prof_vs_weights(use_weights=True)
        ftr.prof_vs_weights(use_weights=False)
        sys.exit()

    ftr.phaseogram(plotfile=f"{ftr.model.PSR.value}_pre.png")
    like_start = ftr.lnlikelihood(ftr, ftr.get_parameters())
    log.info("Starting Pulse Likelihood:\t%f" % like_start)

    # Set up the initial conditions for the emcee walkers
    ndim = ftr.n_fit_params
    if args.samples is None:
        pos = None
    else:
        chains = pickle.load(open(args.samples, "rb"))
        chains = np.reshape(chains, [nwalkers, -1, ndim])
        pos = chains[:, -1, :]

    ftr.fit_toas(
        nsteps, pos=pos, priorerrfact=args.priorerrfact, errfact=args.initerrfact
    )

    def plot_chains(chain_dict, file=False):
        npts = len(chain_dict)
        fig, axes = plt.subplots(npts, 1, sharex=True, figsize=(8, 9))
        for ii, name in enumerate(chain_dict.keys()):
            axes[ii].plot(chain_dict[name], color="k", alpha=0.3)
            axes[ii].set_ylabel(name)
        axes[npts - 1].set_xlabel("Step Number")
        fig.tight_layout()
        if file:
            fig.savefig(file)
            plt.close()
        else:
            plt.show()
            plt.close()

    chains = sampler.chains_to_dict(ftr.fitkeys)
    plot_chains(chains, file=f"{ftr.model.PSR.value}_chains.png")

    # Make the triangle plot.
    # samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params))
    samples = np.transpose(
        sampler.sampler.get_chain(discard=burnin), (1, 0, 2)
    ).reshape((-1, ftr.n_fit_params))
    with contextlib.suppress(ImportError):
        import corner

        fig = corner.corner(
            samples,
            labels=ftr.fitkeys,
            bins=50,
            truths=ftr.maxpost_fitvals,
            plot_contours=True,
        )
        fig.savefig(f"{ftr.model.PSR.value}_triangle.png")
        plt.close()

    # Make a phaseogram with the 50th percentile values
    # ftr.set_params(dict(zip(ftr.fitkeys, np.percentile(samples, 50, axis=0))))
    # Make a phaseogram with the best MCMC result
    ftr.set_parameters(ftr.maxpost_fitvals)
    ftr.phaseogram(plotfile=f"{ftr.model.PSR.value}_post.png")
    plt.close()

    with open(f"{ftr.model.PSR.value}_post.par", "w") as f:
        f.write(ftr.model.as_parfile())

    # Print the best MCMC values and ranges
    ranges = list(
        map(
            lambda v: (v[1], v[2] - v[1], v[1] - v[0]),
            zip(*np.percentile(samples, [16, 50, 84], axis=0)),
        )
    )
    log.info("Post-MCMC values (50th percentile +/- 16th/84th percentiles):")
    for name, vals in zip(ftr.fitkeys, ranges):
        log.info("%8s:" % name + "%25.15g (+ %12.5g / - %12.5g)" % vals)

    with open(f"{ftr.model.PSR.value}_results.txt", "w") as f:
        f.write("Post-MCMC values (50th percentile +/- 16th/84th percentiles):\n")
        for name, vals in zip(ftr.fitkeys, ranges):
            f.write("%8s:" % name + " %25.15g (+ %12.5g / - %12.5g)\n" % vals)
        f.write("\nMaximum likelihood par file:\n")
        f.write(ftr.model.as_parfile())

    with open(f"{ftr.model.PSR.value}_samples.pickle", "wb") as smppkl:
        pickle.dump(samples, smppkl)
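
# Example invocation from Python (a sketch; the input file names are
# hypothetical):
#
#   main(["eventfiles.txt", "J0534+2200.par", "--ft2", "lat_spacecraft.fits",
#         "--nwalkers", "100", "--nsteps", "500", "--burnin", "50"])
#
# On completion the script writes <PSR>_pre.png, <PSR>_chains.png,
# <PSR>_triangle.png, <PSR>_post.png, <PSR>_post.par, <PSR>_results.txt,
# and <PSR>_samples.pickle, where <PSR> is the pulsar name from the par file.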