"""The BT (Blandford & Teukolsky) model."""
import numpy as np
from pint.models.parameter import floatParameter
from pint.models.pulsar_binary import PulsarBinary
from pint.models.stand_alone_psr_binaries.BT_model import BTmodel
from pint.models.stand_alone_psr_binaries.BT_piecewise import BTpiecewise
from pint.models.timing_model import MissingParameter, TimingModel
import astropy.units as u
from pint import GMsun, Tsun, ls
from astropy.table import Table
from astropy.time import Time
from pint.models.parameter import (
MJDParameter,
floatParameter,
prefixParameter,
maskParameter,
)
from pint.toa_select import TOASelect
[docs]class BinaryBT(PulsarBinary):
"""Blandford and Teukolsky binary model.
This binary model is described in Blandford and Teukolshy 1976. It is
a relatively simple parametrized post-Keplerian model that does not
support Shapiro delay calculations.
The actual calculations for this are done in
:class:`pint.models.stand_alone_psr_binaries.BT_model.BTmodel`.
Parameters supported:
.. paramtable::
:class: pint.models.binary_bt.BinaryBT
Notes
-----
Because PINT's binary models all support specification of multiple orbital
frequency derivatives FBn, this is capable of behaving like the model called
BTX in tempo2. The model called BTX in tempo instead supports multiple
(non-interacting) companions, and that is not supported here. Neither can
PINT accept "BTX" as an alias for this model.
See Blandford & Teukolsky 1976, ApJ, 205, 580.
"""
register = True
def __init__(self):
super().__init__()
self.binary_model_name = "BT"
self.binary_model_class = BTmodel
self.add_param(
floatParameter(
name="GAMMA",
value=0.0,
units="second",
description="Time dilation & gravitational redshift",
)
)
self.remove_param("M2")
self.remove_param("SINI")
[docs] def validate(self):
"""Validate BT model parameters"""
super().validate()
for p in ("T0", "A1"):
if getattr(self, p).value is None:
raise MissingParameter("BT", p, f"{p} is required for BT")
# If any *DOT is set, we need T0
for p in ("PBDOT", "OMDOT", "EDOT", "A1DOT"):
if getattr(self, p).value is None:
getattr(self, p).value = "0"
getattr(self, p).frozen = True
if self.GAMMA.value is None:
self.GAMMA.value = "0"
self.GAMMA.frozen = True
"""The BT (Blandford & Teukolsky) model with piecewise orbital parameters.
See Blandford & Teukolsky 1976, ApJ, 205, 580.
"""
[docs]class BinaryBTPiecewise(PulsarBinary):
"""Model implementing the BT model with piecewise orbital parameters A1X and T0X. This model lets the user specify time ranges and fit for a different piecewise orbital parameter in each time range,
This is a PINT pulsar binary BTPiecewise model class, a subclass of PulsarBinary.
It is a wrapper for stand alone BTPiecewise class defined in
./stand_alone_psr_binary/BT_piecewise.py
The aim for this class is to connect the stand alone binary model with the PINT platform.
BTpiecewise special parameters, where xxxx denotes the 4-digit index of the piece:
T0X_xxxx Piecewise T0 values for piece
A1X_xxxx Piecewise A1 values for piece
XR1_xxxx Lower time boundary of piece
XR2_xxxx Upper time boundary of piece
"""
register = True
def __init__(self):
super(BinaryBTPiecewise, self).__init__()
self.binary_model_name = "BT_piecewise"
self.binary_model_class = BTpiecewise
self.add_param(
floatParameter(
name="GAMMA",
value=0.0,
units="second",
description="Time dilation & gravitational redshift",
)
)
self.A1_value_funcs = []
self.T0_value_funcs = []
self.remove_param("M2")
self.remove_param("SINI")
self.add_group_range(None, None)
self.add_piecewise_param(0, T0=0 * u.d)
self.add_piecewise_param(0, A1=0 * ls)
[docs] def add_group_range(
self,
group_start_mjd,
group_end_mjd,
piece_index=None,
):
"""Add an orbital piecewise parameter group range. If piece_index is not provided a new piece will be added with index equal to the number of pieces plus one. Pieces cannot have the duplicate pieces and cannot have the same index. A pair of consisting of a piecewise A1 and T0 may share an index and will act over the same piece range.
Parameters
----------
group_start_mjd : float or astropy.quantity.Quantity or astropy.time.Time
MJD for the piece lower boundary
group_end_mjd : float or astropy.quantity.Quantity or astropy.time.Time
MJD for the piece upper boundary
piece_index : int
Number to label the piece being added.
"""
if group_start_mjd is not None and group_end_mjd is not None:
if isinstance(group_start_mjd, Time):
group_start_mjd = group_start_mjd.mjd
elif isinstance(group_start_mjd, u.quantity.Quantity):
group_start_mjd = group_start_mjd.value
if isinstance(group_end_mjd, Time):
group_end_mjd = group_end_mjd.mjd
elif isinstance(group_end_mjd, u.quantity.Quantity):
group_end_mjd = group_end_mjd.value
elif group_start_mjd is None or group_end_mjd is None:
if group_start_mjd is None and group_end_mjd is not None:
group_start_mjd = group_end_mjd - 100
elif group_start_mjd is not None and group_end_mjd is None:
group_end_mjd = group_start_mjd + 100
else:
group_start_mjd = 50000
group_end_mjd = 60000
if piece_index is None:
dct = self.get_prefix_mapping_component("XR1_")
if len(list(dct.keys())) > 0:
piece_index = np.max(list(dct.keys())) + 1
else:
piece_index = 0
# check the validity of the desired group to add
if group_end_mjd is not None and group_start_mjd is not None:
if group_end_mjd <= group_start_mjd:
raise ValueError("Starting MJD is greater than ending MJD.")
elif piece_index < 0:
raise ValueError(
f"Invalid index for group: {piece_index} should be greater than or equal to 0"
)
elif piece_index > 9999:
raise ValueError(
f"Invalid index for group. Cannot index beyond 9999 (yet?)"
)
i = f"{int(piece_index):04d}"
self.add_param(
prefixParameter(
name="XR1_{0}".format(i),
units="MJD",
description="Beginning of paramX interval",
parameter_type="MJD",
time_scale="utc",
value=group_start_mjd,
)
)
self.add_param(
prefixParameter(
name="XR2_{0}".format(i),
units="MJD",
description="End of paramX interval",
parameter_type="MJD",
time_scale="utc",
value=group_end_mjd,
)
)
self.setup()
[docs] def remove_range(self, index):
"""Removes all orbital piecewise parameters associated with a given index/list of indices.
Parameters
----------
index : float, int, list, np.ndarray
Number or list/array of numbers corresponding to T0X/A1X indices to be removed from model.
"""
if (
isinstance(index, int)
or isinstance(index, float)
or isinstance(index, np.int64)
):
indices = [index]
elif not isinstance(index, list) or not isinstance(index, np.ndarray):
raise TypeError(
f"index must be a float, int, list, or array - not {type(index)}"
)
for index in indices:
index_rf = f"{int(index):04d}"
for prefix in ["T0X_", "A1X_", "XR1_", "XR2_"]:
if hasattr(self, f"{prefix+index_rf}"):
self.remove_param(prefix + index_rf)
if hasattr(self.binary_instance, "param_pieces"):
if len(self.binary_instance.param_pieces) > 0:
temp_piece_information = []
for item in self.binary_instance.param_pieces:
if item[0] != index_rf:
temp_piece_information.append(item)
self.binary_instance.param_pieces = temp_piece_information
# self.binary_instance.param_pieces = self.binary_instance.param_pieces.remove('index_rf')
self.validate()
self.setup()
[docs] def add_piecewise_param(self, piece_index, **kwargs):
"""Add an orbital piecewise parameter.
Parameters
----------
piece_index : int
Number to label the piece being added. Expected to match a set of piece boundaries.
param : str
Piecewise parameter label e.g. "T0" or "A1".
paramx : np.float128 or astropy.quantity.Quantity
Piecewise parameter value.
"""
for key in ("T0", "A1"):
if key in kwargs:
param = key
paramx = kwargs[key]
if key == "T0":
param_unit = u.d
if isinstance(paramx, u.quantity.Quantity):
paramx = paramx.value
elif isinstance(paramx, np.float128):
paramx = paramx
elif isinstance(paramx, Time):
paramx = paramx.mjd
else:
raise ValueError(
"Unspported data type '%s' for piecewise T0. Ensure the piecewise parameter value is a np.float128, Time or astropy.quantity.Quantity"
% type(paramx)
)
elif key == "A1":
param_unit = ls
if isinstance(paramx, u.quantity.Quantity):
paramx = paramx.value
elif isinstance(paramx, np.float64):
paramx = paramx
else:
raise ValueError(
"Unspported data type '%s' for piecewise A1. Ensure the piecewise parameter value is a np.float64 or astropy.quantity.Quantity"
% type(paramx)
)
key_found = True
if not key_found:
raise AttributeError(
"No piecewise parameters passed. Use T0 = / A1 = to declare a piecewise variable."
)
if piece_index is None:
dct = self.get_prefix_mapping_component(param + "X_")
if len(list(dct.keys())) > 0:
piece_index = np.max(list(dct.keys())) + 1
else:
piece_index = 0
elif int(piece_index) in self.get_prefix_mapping_component(param + "X_"):
raise ValueError(
"Index '%s' is already in use in this model. Please choose another."
% piece_index
)
i = f"{int(piece_index):04d}"
# handling if None are passed as arguments
if any(i is None for i in [param, param_unit, paramx]):
if param is not None:
# if parameter value or unit unset, set with default according to param
if param_unit is None:
param_unit = (getattr(self, param)).units
if paramx is None:
paramx = (getattr(self, param)).value
# check if name exists and is currently available
self.add_param(
prefixParameter(
name=param + f"X_{i}",
units=param_unit,
value=paramx,
description="Parameter" + param + "variation",
parameter_type="float",
frozen=False,
)
)
self.setup()
[docs] def setup(self):
"""Raises
------
ValueError
if there are values that have been added without name/ranges associated (should only be raised if add_piecewise_param has been side-stepped with an alternate method)
"""
super().setup()
for bpar in self.params:
self.register_deriv_funcs(self.d_binary_delay_d_xxxx, bpar)
# Setup the model isinstance
self.binary_instance = self.binary_model_class()
# piecewise T0's
T0X_mapping = self.get_prefix_mapping_component("T0X_")
T0Xs = {}
# piecewise A1's (doing piecewise A1's requires more thought and work)
A1X_mapping = self.get_prefix_mapping_component("A1X_")
A1Xs = {}
# piecewise parameter ranges XR1-piece lower bound
XR1_mapping = self.get_prefix_mapping_component("XR1_")
XR1s = {}
# piecewise parameter ranges XR2-piece upper bound
XR2_mapping = self.get_prefix_mapping_component("XR2_")
XR2s = {}
for index in XR1_mapping.values():
index = index.split("_")[1]
piece_index = f"{int(index):04d}"
if hasattr(self, f"T0X_{piece_index}"):
if getattr(self, f"T0X_{piece_index}") is not None:
self.binary_instance.add_binary_params(
f"T0X_{piece_index}", getattr(self, f"T0X_{piece_index}")
)
else:
self.binary_instance.add_binary_params(
f"T0X_{piece_index}", self.T0.value
)
if hasattr(self, f"A1X_{piece_index}"):
if hasattr(self, f"A1X_{piece_index}"):
if getattr(self, f"A1X_{piece_index}") is not None:
self.binary_instance.add_binary_params(
f"A1X_{piece_index}", getattr(self, f"A1X_{piece_index}")
)
else:
self.binary_instance.add_binary_params(
f"A1X_{piece_index}", self.A1.value
)
if hasattr(self, f"XR1_{piece_index}"):
if getattr(self, f"XR1_{piece_index}") is not None:
self.binary_instance.add_binary_params(
f"XR1_{piece_index}", getattr(self, f"XR1_{piece_index}")
)
else:
raise ValueError(f"No date provided to create a group with")
else:
raise ValueError(f"No name provided to create a group with")
if hasattr(self, f"XR2_{piece_index}"):
if getattr(self, f"XR2_{piece_index}") is not None:
self.binary_instance.add_binary_params(
f"XR2_{piece_index}", getattr(self, f"XR2_{piece_index}")
)
else:
raise ValueError(f"No date provided to create a group with")
else:
raise ValueError(f"No name provided to create a group with")
self.update_binary_object(None)
[docs] def validate(self):
"""Include catches for overlapping groups. etc
Raises
------
ValueError
if there are pieces with no associated boundaries (T0X_0000 does not have a corresponding XR1_0000/XR2_0000)
ValueError
if any boundaries overlap (as it makes TOA assignment to a single group ambiguous). i.e. XR1_0000<XR2_0000 and XR2_0000>XR1_0001
ValueError
if the number of lower and upper bounds don't match (should only be raised if XR1 is defined without XR2 and validate is run or vice versa)
"""
super().validate()
for p in ("T0", "A1"):
if getattr(self, p).value is None:
raise MissingParameter("BT", p, "%s is required for BT" % p)
# If any *DOT is set, we need T0
for p in ("PBDOT", "OMDOT", "EDOT", "A1DOT"):
if getattr(self, p).value is None:
getattr(self, p).set("0")
getattr(self, p).frozen = True
if getattr(self, p).value is not None:
if self.T0.value is None:
raise MissingParameter("BT", "T0", "T0 is required if *DOT is set")
if self.GAMMA.value is None:
self.GAMMA.set("0")
self.GAMMA.frozen = True
dct_plb = self.get_prefix_mapping_component("XR1_")
dct_pub = self.get_prefix_mapping_component("XR2_")
dct_T0X = self.get_prefix_mapping_component("T0X_")
dct_A1X = self.get_prefix_mapping_component("A1X_")
if len(dct_plb) > 0 and len(dct_pub) > 0:
ls_plb = list(dct_plb.items())
ls_pub = list(dct_pub.items())
ls_T0X = list(dct_T0X.items())
ls_A1X = list(dct_A1X.items())
j_plb = [((tup[1]).split("_"))[1] for tup in ls_plb]
j_pub = [((tup[1]).split("_"))[1] for tup in ls_pub]
j_T0X = [((tup[1]).split("_"))[1] for tup in ls_T0X]
j_A1X = [((tup[1]).split("_"))[1] for tup in ls_A1X]
if j_plb != j_pub:
raise ValueError(
f"Group boundary mismatch error. Number of detected lower bounds: {j_plb}. Number of detected upper bounds: {j_pub}"
)
if len(np.setdiff1d(j_plb, j_pub)) > 0:
raise ValueError(
f"Group index mismatch error. Check the indexes of XR1_/XR2_ parameters in the model"
)
if not len(ls_A1X) > 0:
if len(ls_pub) > 0 and len(ls_T0X) > 0:
if len(np.setdiff1d(j_pub, j_T0X)) > 0:
raise ValueError(
f"Group index mismatch error. Check the indexes of T0X groups, make sure they match there are corresponding group ranges (XR1/XR2)"
)
if not len(ls_T0X) > 0:
if len(ls_pub) > 0 and len(ls_A1X) > 0:
if len(np.setdiff1d(j_pub, j_A1X)) > 0:
raise ValueError(
f"Group index mismatch error. Check the indexes of A1X groups, make sure they match there are corresponding group ranges (/XR2)"
)
lb = [(getattr(self, tup[1])).value for tup in ls_plb]
ub = [(getattr(self, tup[1])).value for tup in ls_pub]
for i in range(len(lb)):
for j in range(len(lb)):
if i != j:
if max(lb[i], lb[j]) < min(ub[i], ub[j]):
raise ValueError(
f"Group boundary overlap detected. Make sure groups are not overlapping"
)
[docs] def paramx_per_toa(self, param_name, toas):
"""Find the piecewise parameter value each toa will reference during calculations
Parameters
----------
param_name : string
which piecewise parameter to show: 'A1'/'T0'. TODO this should raise an error if not present)
toa : pint.toa.TOA
Returns
-------
u.quantity.Quantity
length(toa) elements are T0X or A1X values to reference for each toa during binary calculations.
"""
condition = {}
tbl = toas.table
XR1_mapping = self.get_prefix_mapping_component("XR1_")
XR2_mapping = self.get_prefix_mapping_component("XR2_")
if not hasattr(self, "toas_selector"):
self.toas_selector = TOASelect(is_range=True)
if param_name[0:2] == "T0":
paramX_mapping = self.get_prefix_mapping_component("T0X_")
param_unit = u.d
elif param_name[0:2] == "A1":
paramX_mapping = self.get_prefix_mapping_component("A1X_")
param_unit = ls
else:
raise AttributeError(
"param '%s' not found. Please choose another. Currently implemented: 'T0' or 'A1' "
% param_name
)
for piece_index in paramX_mapping.keys():
r1 = getattr(self, XR1_mapping[piece_index]).quantity
r2 = getattr(self, XR2_mapping[piece_index]).quantity
condition[paramX_mapping[piece_index]] = (r1.mjd, r2.mjd)
select_idx = self.toas_selector.get_select_index(condition, tbl["mjd_float"])
paramx = np.zeros(len(tbl)) * param_unit
for k, v in select_idx.items():
paramx[v] += getattr(self, k).quantity
for i in range(len(paramx)):
if paramx[i] == 0:
paramx[i] = (getattr(self, param_name[0:2])).value * param_unit
return paramx
[docs] def get_number_of_groups(self):
"""Get the number of piecewise parameters"""
return len(self.binary_instance.piecewise_parameter_information)
[docs] def which_group_is_toa_in(self, toa):
"""Find the group a toa belongs to based on the boundaries of groups passed to BT_piecewise
Parameters
----------
Returns
-------
list
str elements, look like ['0000','0001'] for two TOAs where one refences T0X/A1X.
"""
# if isinstance(toa, pint.toa.TOAs):
# pass
# else:
# raise TypeError(f'toa must be a Time or pint.toa.TOAs - not {type(toa)}')
tbl = toa.table
condition = {}
XR1_mapping = self.get_prefix_mapping_component("XR1_")
XR2_mapping = self.get_prefix_mapping_component("XR2_")
if not hasattr(self, "toas_selector"):
self.toas_selector = TOASelect(is_range=True)
boundaries = {}
for piece_index in XR1_mapping.keys():
r1 = getattr(self, XR1_mapping[piece_index]).quantity
r2 = getattr(self, XR2_mapping[piece_index]).quantity
condition[(XR1_mapping[piece_index]).split("_")[-1]] = (r1.mjd, r2.mjd)
select_idx = self.toas_selector.get_select_index(condition, tbl["mjd_float"])
paramx = np.empty(len(tbl), dtype="<U4")
for k, v in select_idx.items():
paramx[v] = k
return paramx.tolist()