Module diskchef.fitting.fitters
Fitter for diskchef
Expand source code
"""Fitter for diskchef"""
import inspect
import pickle
import logging
import os
from collections.abc import Iterable
from dataclasses import dataclass, field
from functools import partial
from multiprocessing import Pool
from matplotlib.figure import Figure
from pathlib import Path
from typing import Callable, List, Union, Literal, Dict
from itertools import product, cycle, zip_longest
import numpy as np
import ultranest
from astropy import units as u
from astropy.table import QTable
import corner
from matplotlib import pyplot as plt
import scipy.optimize
from emcee import EnsembleSampler
from matplotlib.colors import LogNorm
import diskchef.engine.other
from diskchef.engine.exceptions import CHEFValueError, CHEFNotImplementedError
from diskchef.engine.overplot_scatter import overplot_scatter, overplot_hexbin
@dataclass
class Parameter:
"""
Class that handles parameters for diskchef Fitters
Args:
name: str - name of the parameter
min: float
max: float - prior minimal and maximal value of the parameter
truth: optional, float - expected value of the parameter
format_: str - python format-string for the parameter output
log: bool - whether the logarithm of the valus should be used for fitting instead
Fields:
fitted: float - fitted value
fitted_error: float - 1-sigma error of the fitted value
fitted_error_up: float - upper error of the fitted value
fitted_error_down: float - lower error of the fitted value
"""
name: str
min: Union[u.Quantity, float] = None
max: Union[u.Quantity, float] = None
truth: float = None
format_: str = "{:.2f}"
log: bool = False
def __post_init__(self):
self.fitted = None
self.fitted_error = None
self.fitted_error_up = None
self.fitted_error_down = None
@property
def math_repr(self) -> str:
"""
Returns matplotlib/LaTeX-formatted representation of the parameter and its fitted value
"""
out = "$"
if self.fitted is None:
out += self.name
elif self.fitted_error is None:
out += f"{self.name} = {self.format_.format(self.fitted)}"
elif (self.fitted_error_down is None) or (self.fitted_error_up is None):
out += f"{self.name} = {self.format_.format(self.fitted)} ± {self.format_.format(self.fitted_error)}"
else:
out += f"{self.name} = {self.format_.format(self.fitted)}" \
f"^{{+{self.format_.format(self.fitted_error_up)}}}" \
f"_{{-{self.format_.format(self.fitted_error_down)}}}"
if self.truth is not None:
out += f" ({self.format_.format(self.truth)})"
out += "$"
return out
def __str__(self):
return f"${self.name}$"
def __eq__(self, other):
"""
Checks whether right parameter is within left parameter's error bar
"""
if self.fitted is None:
return False
elif self.fitted_error is None:
return self.fitted == other
elif (self.fitted_error_up is None) or (self.fitted_error_down is None):
return self.fitted - self.fitted_error <= other <= self.fitted + self.fitted_error
else:
return self.fitted - self.fitted_error_down <= other <= self.fitted + self.fitted_error_up
@dataclass
class Fitter:
"""Base class for finding parameters for lnprob which maximizes its output"""
lnprob: Callable
parameters: List[Parameter]
threads: int = None
progress: bool = False
hexbin: bool = True
fitter_kwargs: dict = field(default_factory=dict)
def __post_init__(self):
self.logger = logging.getLogger(__name__ + '.' + self.__class__.__qualname__)
self.logger.info("Creating an instance of %s", self.__class__.__qualname__)
self.logger.debug("With parameters: %s", self.__dict__)
self._table = None
self.sampler = None
self._check_lnprob()
for parameter in self.parameters:
if any(
fit_result is not None for fit_result in
[parameter.fitted, parameter.fitted_error,
parameter.fitted_error_up, parameter.fitted_error_down]
):
self.logger.info("Parameter %s is already fit! Clearing the fitted values.", parameter)
parameter.fitted = parameter.fitted_error = None
parameter.fitted_error_up = parameter.fitted_error_down = None
self.logger.info("Parameter %s is cleaned.", parameter)
def _check_lnprob(self):
if not callable(self.lnprob):
raise CHEFValueError("lnprob must be callable!")
defaults = [
param.default
for name, param
in inspect.signature(self.lnprob).parameters.items()
if name != 'self'
]
if defaults[0] is not inspect.Parameter.empty:
self.logger.warning("lnprob first argument has a default value!")
if not isinstance(defaults[0], Iterable):
self.logger.error("First argument of lnprob should be an array of parameters! Continuing anyway.")
if inspect.Parameter.empty in defaults[1:]:
self.logger.error("lnprob should have only one non-default argument! Continuing anyway.")
def _post_fit(self):
"""
Private method to run after the fitting is complete
"""
fig = self.corner()
fig.savefig("corner.pdf")
self.save()
def lnprob_fixed(self, *args, **kwargs):
"""Decorates self.lnprob so that minimum and maximum from self.parameters are considered"""
for parameter, arg in zip(self.parameters, *args):
if (parameter.min is not None) and (parameter.min > arg):
return -np.inf
if (parameter.max is not None) and (parameter.max < arg):
return -np.inf
try:
return self.lnprob(*args, **kwargs)
except Exception as e:
self.logger.error("lnprob crushed during the function call with %s %s",
args, kwargs)
self.logger.error("%s", e)
return -np.inf
@property
def table(self):
return self._table
def fit(
self,
*args, **kwargs
) -> Dict[str, Parameter]:
raise CHEFNotImplementedError
def corner(self, scatter_kwargs=None, hexbin_kwargs=None, **kwargs) -> Figure:
if scatter_kwargs is None:
scatter_kwargs = {}
if hexbin_kwargs is None:
hexbin_kwargs = {}
if "cmap" not in hexbin_kwargs.items():
hexbin_kwargs["cmap"] = "afmhot"
data = np.array([self.table[param.name] for param in self.parameters]).T
labels = [str(param) for param in self.parameters]
truths = [parameter.truth for parameter in self.parameters]
if "weight" in self.table.colnames:
weights = self.table["weight"]
cumsumweights = np.cumsum(weights)
mask = cumsumweights > 1e-4
else:
weights = np.ones_like(data[:, 0])
mask = Ellipsis
try:
fig = corner.corner(
data[mask], weights=weights[mask], labels=labels, show_titles=False, truths=truths, **kwargs
)
mins = []
maxs = []
medians = []
for param in self.parameters:
if param.fitted_error_down is not None:
min_ = param.fitted - param.fitted_error_down
elif param.fitted_error is not None:
min_ = param.fitted - param.fitted_error
else:
min_ = None
if param.fitted_error_up is not None:
max_ = param.fitted + param.fitted_error_up
elif param.fitted_error is not None:
max_ = param.fitted + param.fitted_error
else:
max_ = None
if param.fitted is not None:
median_ = param.fitted
else:
median_ = None
mins.append(min_)
maxs.append(max_)
medians.append(median_)
corner.overplot_lines(fig, mins, color='k', alpha=0.2, linestyle=(0, (5, 5)))
corner.overplot_lines(fig, maxs, color='k', alpha=0.2, linestyle=(0, (5, 5)))
corner.overplot_lines(fig, medians, color='k', alpha=1, linestyle=(0, (5, 5)))
self._decorate_corner(fig)
except AssertionError:
return Figure()
# overplot_scatter(fig, data, c=-self.table["lnprob"], norm=LogNorm(), **scatter_kwargs)
if self.hexbin:
overplot_hexbin(
fig, data[mask], C=-self.table["lnprob"][mask],
norm=diskchef.engine.other.LogNormMaxOrders(maxdepth=1e4), reduce_C_function=np.nanmin,
**hexbin_kwargs
)
return fig
def _decorate_corner(self, fig: Figure):
ndim = len(self.parameters)
axes = np.array(fig.axes).reshape((ndim, ndim))
for ax, parameter in zip(axes.diagonal(), self.parameters):
ax.set_title(parameter.math_repr, size="small")
@property
def parameters_dict(self) -> Dict[str, Parameter]:
return {parameter.name: parameter for parameter in self.parameters}
def save(self, filename: Path = "fitter.sav"):
"""Saves the fitter to a file"""
try:
with open(filename, "wb") as fff:
pickle.dump(self, fff)
self.logger.info("Fitter successfully pickled to %s", filename)
except Exception as e:
self.logger.error("Could not save fitter! %s", e)
@classmethod
def load(cls, filename: Path = "fitter.sav") -> "Fitter":
"""Return the loaded fitter from a file"""
try:
with open(filename, "rb") as fff:
fitter = pickle.load(fff)
fitter.logger.info("Fitter was successfully unpickled from %s", filename)
return fitter
except Exception as e:
raise e
@dataclass
class BruteForceFitter(Fitter):
n_points: Union[int, List[int]] = 10
def fit(
self,
*args, **kwargs
):
pars = []
if not hasattr(self.n_points, "__len__"):
self.n_points = [self.n_points] * len(self.parameters)
for parameter, length in zip_longest(self.parameters, self.n_points):
par_range = np.linspace(parameter.min, parameter.max, length)
pars.append(par_range)
all_parameter_combinations = product(*pars)
tbl = QTable(np.array([comb for comb in all_parameter_combinations]),
names=[par.name for par in self.parameters])
if self.threads != 1:
lnprob = partial(self.lnprob_fixed, *args, **kwargs)
with Pool(self.threads) as pool:
tbl["lnprob"] = pool.map(lnprob, tbl)
else:
tbl["lnprob"] = [self.lnprob_fixed(parameters, *args, **kwargs) for parameters in tbl]
tbl["weight"] = np.exp(tbl["lnprob"] - tbl["lnprob"].max())
self._table = tbl
argmax_row = tbl[np.argmax(tbl["lnprob"])]
for parameter, n_points in zip(self.parameters, self.n_points):
parameter.fitted = argmax_row[parameter.name]
err = (parameter.max - parameter.min) / (n_points - 1)
parameter.fitted_error_up = err
parameter.fitted_error_down = err
parameter.fitted_error = err
self._post_fit()
return self.parameters_dict
@dataclass
class EMCEEFitter(Fitter):
nwalkers: int = 100
nsteps: int = 100
burn_steps: int = 30
burn_strategy: Literal[None, 'best'] = None
def fit(
self,
*args, **kwargs
):
pos0 = (np.random.random((self.nwalkers, len(self.parameters))) # 0--1
* np.array([param.max - param.min for param in self.parameters]) # * range
+ np.array([param.min for param in self.parameters])) # + min
if self.threads != 1:
pool = Pool(self.threads)
else:
pool = None
sampler = EnsembleSampler(self.nwalkers, len(self.parameters), self.lnprob_fixed, args=args, kwargs=kwargs,
pool=pool, **self.fitter_kwargs)
sampler.run_mcmc(pos0, self.burn_steps, progress=self.progress)
if self.burn_strategy is None:
pos1 = None
elif self.burn_strategy == 'best':
pos1 = np.tile(sampler.flatchain[np.argmax(sampler.flatlnprobability)], [self.nwalkers, 1])
else:
raise CHEFValueError("burn_strategy should be None or 'best'")
sampler.run_mcmc(pos1, self.nsteps - self.burn_steps, progress=self.progress)
if pool is not None:
pool.close()
self.sampler = sampler
tbl = QTable(sampler.flatchain[self.burn_steps * self.nwalkers:], names=[par.name for par in self.parameters])
tbl["lnprob"] = sampler.flatlnprobability[self.burn_steps * self.nwalkers:]
self._table = tbl
for i, parameter in enumerate(self.parameters):
results = np.percentile(sampler.flatchain[self.burn_steps * self.nwalkers:, i], [16, 50, 84])
parameter.fitted = results[1]
parameter.fitted_error_up = results[2] - results[1]
parameter.fitted_error_down = results[1] - results[0]
parameter.fitted_error = (parameter.fitted_error_up + parameter.fitted_error_down) / 2
self._post_fit()
return self.parameters_dict
@dataclass
class UltraNestFitter(Fitter):
nwalkers: int = 100
nsteps: int = 100
transform: Callable = None
resume: Literal[True, 'resume', 'resume-similar', 'overwrite', 'subfolder'] = 'overwrite'
log_dir: Union[str, Path] = "ultranest"
run_kwargs: dict = field(default_factory=dict)
plot_corner: bool = True
storage_backend: Literal['hdf5', 'csv', 'tsv'] = 'hdf5'
DEFAULT_FOR_RUN_KWARGS = dict(
Lepsilon=0.01, # Increase when lnprob is inaccurate
frac_remain=0.05, # Decrease if lnprob is expected to have peaks
min_num_live_points=100,
dlogz=1.,
dKL=1.,
)
INFINITY = 1e50
def __post_init__(self):
super().__post_init__()
if self.transform is None:
self.transform = self.rescale
self.sampler = None
for key, value in self.DEFAULT_FOR_RUN_KWARGS.items():
if key not in self.run_kwargs:
self.run_kwargs[key] = value
def rescale(self, cube):
params = np.empty_like(cube)
for i, parameter in enumerate(self.parameters):
if parameter.log:
params[i] = 10 ** (cube[i] * (np.log10(parameter.max) - np.log10(parameter.min))
+ np.log10(parameter.min))
else:
params[i] = cube[i] * (parameter.max - parameter.min) + parameter.min
return params
def lnprob_fixed(self, *args, **kwargs):
return np.nan_to_num(
super().lnprob_fixed(*args, **kwargs),
neginf=-self.INFINITY, posinf=self.INFINITY, nan=-self.INFINITY
)
def save(self, filename: Path = "fitter.sav"):
"""Saves the fitter to a file. Does not save ultranest.ReactiveNestedSampler object!"""
sampler = self.sampler
del self.sampler
try:
with open(filename, "wb") as fff:
pickle.dump(self, fff)
self.logger.info("Fitter successfully pickled to %s. Sampler was not picked!", filename)
except Exception as e:
self.logger.error("Could not save fitter! %s", e)
finally:
self.sampler = sampler
def fit(
self,
*args, **kwargs
):
self.log_dir = Path(self.log_dir)
lnprob = partial(self.lnprob_fixed, *args, **kwargs)
lnprob.__name__ = self.lnprob.__name__
self.sampler = ultranest.ReactiveNestedSampler(
[str(param) for param in self.parameters],
lnprob,
self.transform,
log_dir=self.log_dir,
resume=self.resume,
storage_backend=self.storage_backend,
**self.fitter_kwargs
)
for i, result in enumerate(self.sampler.run_iter(**self.run_kwargs)):
if not self.sampler.use_mpi or self.sampler.mpi_rank == 0:
self.logger.info("Step %d: %s", i, result)
tbl = QTable(self.sampler.results['weighted_samples']['points'],
names=[par.name for par in self.parameters])
tbl["lnprob"] = self.sampler.results['weighted_samples']['logl']
tbl["lnprob"][tbl["lnprob"] <= -self.INFINITY] = -np.inf
tbl["weight"] = self.sampler.results['weighted_samples']['weights']
self._table = tbl
results = self.sampler.results['posterior']
for iparam, parameter in enumerate(self.parameters):
parameter.fitted = results['mean'][iparam]
parameter.fitted_error_up = results['errup'][iparam] - results['mean'][iparam]
parameter.fitted_error_down = results['mean'][iparam] - results['errlo'][iparam]
parameter.fitted_error = results['stdev'][iparam]
try:
self.sampler.plot()
except Exception as e:
self.logger.error("Could not plot! %s", e)
if self.plot_corner:
try:
fig = self.corner()
fig.savefig(self.log_dir / f"corner_{i:06d}.pdf")
fig.savefig(self.log_dir / "corner.pdf")
except ValueError as e:
self.logger.error("Could not make corner plot for %d:", i)
self.logger.error(e)
self.save(self.log_dir / f"fitter_{i:06d}.sav")
if not self.sampler.use_mpi or self.sampler.mpi_rank == 0:
self._post_fit()
return self.parameters_dict
@dataclass
class SciPyFitter(Fitter):
method: Union[None, str] = None
def fit(
self,
*args, **kwargs
):
scipy_result: scipy.optimize.OptimizeResult = scipy.optimize.minimize(
self.lnprob,
*args,
**kwargs
)
self.scipy_result = scipy_result
for parameter, result in zip(self.parameters, self.scipy_result.x):
parameter.fitted = result
return self.parameters_dict
Classes
class BruteForceFitter (lnprob: Callable, parameters: List[Parameter], threads: int = None, progress: bool = False, hexbin: bool = True, fitter_kwargs: dict = <factory>, n_points: Union[int, List[int]] = 10)
-
BruteForceFitter(lnprob: Callable, parameters: List[diskchef.fitting.fitters.Parameter], threads: int = None, progress: bool = False, hexbin: bool = True, fitter_kwargs: dict =
, n_points: Union[int, List[int]] = 10) Expand source code
@dataclass class BruteForceFitter(Fitter): n_points: Union[int, List[int]] = 10 def fit( self, *args, **kwargs ): pars = [] if not hasattr(self.n_points, "__len__"): self.n_points = [self.n_points] * len(self.parameters) for parameter, length in zip_longest(self.parameters, self.n_points): par_range = np.linspace(parameter.min, parameter.max, length) pars.append(par_range) all_parameter_combinations = product(*pars) tbl = QTable(np.array([comb for comb in all_parameter_combinations]), names=[par.name for par in self.parameters]) if self.threads != 1: lnprob = partial(self.lnprob_fixed, *args, **kwargs) with Pool(self.threads) as pool: tbl["lnprob"] = pool.map(lnprob, tbl) else: tbl["lnprob"] = [self.lnprob_fixed(parameters, *args, **kwargs) for parameters in tbl] tbl["weight"] = np.exp(tbl["lnprob"] - tbl["lnprob"].max()) self._table = tbl argmax_row = tbl[np.argmax(tbl["lnprob"])] for parameter, n_points in zip(self.parameters, self.n_points): parameter.fitted = argmax_row[parameter.name] err = (parameter.max - parameter.min) / (n_points - 1) parameter.fitted_error_up = err parameter.fitted_error_down = err parameter.fitted_error = err self._post_fit() return self.parameters_dict
Ancestors
Class variables
var n_points : Union[int, List[int]]
Methods
def fit(self, *args, **kwargs)
-
Expand source code
def fit( self, *args, **kwargs ): pars = [] if not hasattr(self.n_points, "__len__"): self.n_points = [self.n_points] * len(self.parameters) for parameter, length in zip_longest(self.parameters, self.n_points): par_range = np.linspace(parameter.min, parameter.max, length) pars.append(par_range) all_parameter_combinations = product(*pars) tbl = QTable(np.array([comb for comb in all_parameter_combinations]), names=[par.name for par in self.parameters]) if self.threads != 1: lnprob = partial(self.lnprob_fixed, *args, **kwargs) with Pool(self.threads) as pool: tbl["lnprob"] = pool.map(lnprob, tbl) else: tbl["lnprob"] = [self.lnprob_fixed(parameters, *args, **kwargs) for parameters in tbl] tbl["weight"] = np.exp(tbl["lnprob"] - tbl["lnprob"].max()) self._table = tbl argmax_row = tbl[np.argmax(tbl["lnprob"])] for parameter, n_points in zip(self.parameters, self.n_points): parameter.fitted = argmax_row[parameter.name] err = (parameter.max - parameter.min) / (n_points - 1) parameter.fitted_error_up = err parameter.fitted_error_down = err parameter.fitted_error = err self._post_fit() return self.parameters_dict
Inherited members
class EMCEEFitter (lnprob: Callable, parameters: List[Parameter], threads: int = None, progress: bool = False, hexbin: bool = True, fitter_kwargs: dict = <factory>, nwalkers: int = 100, nsteps: int = 100, burn_steps: int = 30, burn_strategy: Literal[None, 'best'] = None)
-
EMCEEFitter(lnprob: Callable, parameters: List[diskchef.fitting.fitters.Parameter], threads: int = None, progress: bool = False, hexbin: bool = True, fitter_kwargs: dict =
, nwalkers: int = 100, nsteps: int = 100, burn_steps: int = 30, burn_strategy: Literal[None, 'best'] = None) Expand source code
@dataclass class EMCEEFitter(Fitter): nwalkers: int = 100 nsteps: int = 100 burn_steps: int = 30 burn_strategy: Literal[None, 'best'] = None def fit( self, *args, **kwargs ): pos0 = (np.random.random((self.nwalkers, len(self.parameters))) # 0--1 * np.array([param.max - param.min for param in self.parameters]) # * range + np.array([param.min for param in self.parameters])) # + min if self.threads != 1: pool = Pool(self.threads) else: pool = None sampler = EnsembleSampler(self.nwalkers, len(self.parameters), self.lnprob_fixed, args=args, kwargs=kwargs, pool=pool, **self.fitter_kwargs) sampler.run_mcmc(pos0, self.burn_steps, progress=self.progress) if self.burn_strategy is None: pos1 = None elif self.burn_strategy == 'best': pos1 = np.tile(sampler.flatchain[np.argmax(sampler.flatlnprobability)], [self.nwalkers, 1]) else: raise CHEFValueError("burn_strategy should be None or 'best'") sampler.run_mcmc(pos1, self.nsteps - self.burn_steps, progress=self.progress) if pool is not None: pool.close() self.sampler = sampler tbl = QTable(sampler.flatchain[self.burn_steps * self.nwalkers:], names=[par.name for par in self.parameters]) tbl["lnprob"] = sampler.flatlnprobability[self.burn_steps * self.nwalkers:] self._table = tbl for i, parameter in enumerate(self.parameters): results = np.percentile(sampler.flatchain[self.burn_steps * self.nwalkers:, i], [16, 50, 84]) parameter.fitted = results[1] parameter.fitted_error_up = results[2] - results[1] parameter.fitted_error_down = results[1] - results[0] parameter.fitted_error = (parameter.fitted_error_up + parameter.fitted_error_down) / 2 self._post_fit() return self.parameters_dict
Ancestors
Class variables
var burn_steps : int
var burn_strategy : Literal[None, 'best']
var nsteps : int
var nwalkers : int
Methods
def fit(self, *args, **kwargs)
-
Expand source code
def fit( self, *args, **kwargs ): pos0 = (np.random.random((self.nwalkers, len(self.parameters))) # 0--1 * np.array([param.max - param.min for param in self.parameters]) # * range + np.array([param.min for param in self.parameters])) # + min if self.threads != 1: pool = Pool(self.threads) else: pool = None sampler = EnsembleSampler(self.nwalkers, len(self.parameters), self.lnprob_fixed, args=args, kwargs=kwargs, pool=pool, **self.fitter_kwargs) sampler.run_mcmc(pos0, self.burn_steps, progress=self.progress) if self.burn_strategy is None: pos1 = None elif self.burn_strategy == 'best': pos1 = np.tile(sampler.flatchain[np.argmax(sampler.flatlnprobability)], [self.nwalkers, 1]) else: raise CHEFValueError("burn_strategy should be None or 'best'") sampler.run_mcmc(pos1, self.nsteps - self.burn_steps, progress=self.progress) if pool is not None: pool.close() self.sampler = sampler tbl = QTable(sampler.flatchain[self.burn_steps * self.nwalkers:], names=[par.name for par in self.parameters]) tbl["lnprob"] = sampler.flatlnprobability[self.burn_steps * self.nwalkers:] self._table = tbl for i, parameter in enumerate(self.parameters): results = np.percentile(sampler.flatchain[self.burn_steps * self.nwalkers:, i], [16, 50, 84]) parameter.fitted = results[1] parameter.fitted_error_up = results[2] - results[1] parameter.fitted_error_down = results[1] - results[0] parameter.fitted_error = (parameter.fitted_error_up + parameter.fitted_error_down) / 2 self._post_fit() return self.parameters_dict
Inherited members
class Fitter (lnprob: Callable, parameters: List[Parameter], threads: int = None, progress: bool = False, hexbin: bool = True, fitter_kwargs: dict = <factory>)
-
Base class for finding parameters for lnprob which maximizes its output
Expand source code
@dataclass class Fitter: """Base class for finding parameters for lnprob which maximizes its output""" lnprob: Callable parameters: List[Parameter] threads: int = None progress: bool = False hexbin: bool = True fitter_kwargs: dict = field(default_factory=dict) def __post_init__(self): self.logger = logging.getLogger(__name__ + '.' + self.__class__.__qualname__) self.logger.info("Creating an instance of %s", self.__class__.__qualname__) self.logger.debug("With parameters: %s", self.__dict__) self._table = None self.sampler = None self._check_lnprob() for parameter in self.parameters: if any( fit_result is not None for fit_result in [parameter.fitted, parameter.fitted_error, parameter.fitted_error_up, parameter.fitted_error_down] ): self.logger.info("Parameter %s is already fit! Clearing the fitted values.", parameter) parameter.fitted = parameter.fitted_error = None parameter.fitted_error_up = parameter.fitted_error_down = None self.logger.info("Parameter %s is cleaned.", parameter) def _check_lnprob(self): if not callable(self.lnprob): raise CHEFValueError("lnprob must be callable!") defaults = [ param.default for name, param in inspect.signature(self.lnprob).parameters.items() if name != 'self' ] if defaults[0] is not inspect.Parameter.empty: self.logger.warning("lnprob first argument has a default value!") if not isinstance(defaults[0], Iterable): self.logger.error("First argument of lnprob should be an array of parameters! Continuing anyway.") if inspect.Parameter.empty in defaults[1:]: self.logger.error("lnprob should have only one non-default argument! 
Continuing anyway.") def _post_fit(self): """ Private method to run after the fitting is complete """ fig = self.corner() fig.savefig("corner.pdf") self.save() def lnprob_fixed(self, *args, **kwargs): """Decorates self.lnprob so that minimum and maximum from self.parameters are considered""" for parameter, arg in zip(self.parameters, *args): if (parameter.min is not None) and (parameter.min > arg): return -np.inf if (parameter.max is not None) and (parameter.max < arg): return -np.inf try: return self.lnprob(*args, **kwargs) except Exception as e: self.logger.error("lnprob crushed during the function call with %s %s", args, kwargs) self.logger.error("%s", e) return -np.inf @property def table(self): return self._table def fit( self, *args, **kwargs ) -> Dict[str, Parameter]: raise CHEFNotImplementedError def corner(self, scatter_kwargs=None, hexbin_kwargs=None, **kwargs) -> Figure: if scatter_kwargs is None: scatter_kwargs = {} if hexbin_kwargs is None: hexbin_kwargs = {} if "cmap" not in hexbin_kwargs.items(): hexbin_kwargs["cmap"] = "afmhot" data = np.array([self.table[param.name] for param in self.parameters]).T labels = [str(param) for param in self.parameters] truths = [parameter.truth for parameter in self.parameters] if "weight" in self.table.colnames: weights = self.table["weight"] cumsumweights = np.cumsum(weights) mask = cumsumweights > 1e-4 else: weights = np.ones_like(data[:, 0]) mask = Ellipsis try: fig = corner.corner( data[mask], weights=weights[mask], labels=labels, show_titles=False, truths=truths, **kwargs ) mins = [] maxs = [] medians = [] for param in self.parameters: if param.fitted_error_down is not None: min_ = param.fitted - param.fitted_error_down elif param.fitted_error is not None: min_ = param.fitted - param.fitted_error else: min_ = None if param.fitted_error_up is not None: max_ = param.fitted + param.fitted_error_up elif param.fitted_error is not None: max_ = param.fitted + param.fitted_error else: max_ = None if param.fitted is not 
None: median_ = param.fitted else: median_ = None mins.append(min_) maxs.append(max_) medians.append(median_) corner.overplot_lines(fig, mins, color='k', alpha=0.2, linestyle=(0, (5, 5))) corner.overplot_lines(fig, maxs, color='k', alpha=0.2, linestyle=(0, (5, 5))) corner.overplot_lines(fig, medians, color='k', alpha=1, linestyle=(0, (5, 5))) self._decorate_corner(fig) except AssertionError: return Figure() # overplot_scatter(fig, data, c=-self.table["lnprob"], norm=LogNorm(), **scatter_kwargs) if self.hexbin: overplot_hexbin( fig, data[mask], C=-self.table["lnprob"][mask], norm=diskchef.engine.other.LogNormMaxOrders(maxdepth=1e4), reduce_C_function=np.nanmin, **hexbin_kwargs ) return fig def _decorate_corner(self, fig: Figure): ndim = len(self.parameters) axes = np.array(fig.axes).reshape((ndim, ndim)) for ax, parameter in zip(axes.diagonal(), self.parameters): ax.set_title(parameter.math_repr, size="small") @property def parameters_dict(self) -> Dict[str, Parameter]: return {parameter.name: parameter for parameter in self.parameters} def save(self, filename: Path = "fitter.sav"): """Saves the fitter to a file""" try: with open(filename, "wb") as fff: pickle.dump(self, fff) self.logger.info("Fitter successfully pickled to %s", filename) except Exception as e: self.logger.error("Could not save fitter! %s", e) @classmethod def load(cls, filename: Path = "fitter.sav") -> "Fitter": """Return the loaded fitter from a file""" try: with open(filename, "rb") as fff: fitter = pickle.load(fff) fitter.logger.info("Fitter was successfully unpickled from %s", filename) return fitter except Exception as e: raise e
Subclasses
Class variables
var fitter_kwargs : dict
var hexbin : bool
var lnprob : Callable
var parameters : List[Parameter]
var progress : bool
var threads : int
Static methods
def load(filename: pathlib.Path = 'fitter.sav') ‑> Fitter
-
Return the loaded fitter from a file
Expand source code
@classmethod def load(cls, filename: Path = "fitter.sav") -> "Fitter": """Return the loaded fitter from a file""" try: with open(filename, "rb") as fff: fitter = pickle.load(fff) fitter.logger.info("Fitter was successfully unpickled from %s", filename) return fitter except Exception as e: raise e
Instance variables
var parameters_dict : Dict[str, Parameter]
-
Expand source code
@property def parameters_dict(self) -> Dict[str, Parameter]: return {parameter.name: parameter for parameter in self.parameters}
var table
-
Expand source code
@property def table(self): return self._table
Methods
def corner(self, scatter_kwargs=None, hexbin_kwargs=None, **kwargs) ‑> matplotlib.figure.Figure
-
Expand source code
def corner(self, scatter_kwargs=None, hexbin_kwargs=None, **kwargs) -> Figure: if scatter_kwargs is None: scatter_kwargs = {} if hexbin_kwargs is None: hexbin_kwargs = {} if "cmap" not in hexbin_kwargs.items(): hexbin_kwargs["cmap"] = "afmhot" data = np.array([self.table[param.name] for param in self.parameters]).T labels = [str(param) for param in self.parameters] truths = [parameter.truth for parameter in self.parameters] if "weight" in self.table.colnames: weights = self.table["weight"] cumsumweights = np.cumsum(weights) mask = cumsumweights > 1e-4 else: weights = np.ones_like(data[:, 0]) mask = Ellipsis try: fig = corner.corner( data[mask], weights=weights[mask], labels=labels, show_titles=False, truths=truths, **kwargs ) mins = [] maxs = [] medians = [] for param in self.parameters: if param.fitted_error_down is not None: min_ = param.fitted - param.fitted_error_down elif param.fitted_error is not None: min_ = param.fitted - param.fitted_error else: min_ = None if param.fitted_error_up is not None: max_ = param.fitted + param.fitted_error_up elif param.fitted_error is not None: max_ = param.fitted + param.fitted_error else: max_ = None if param.fitted is not None: median_ = param.fitted else: median_ = None mins.append(min_) maxs.append(max_) medians.append(median_) corner.overplot_lines(fig, mins, color='k', alpha=0.2, linestyle=(0, (5, 5))) corner.overplot_lines(fig, maxs, color='k', alpha=0.2, linestyle=(0, (5, 5))) corner.overplot_lines(fig, medians, color='k', alpha=1, linestyle=(0, (5, 5))) self._decorate_corner(fig) except AssertionError: return Figure() # overplot_scatter(fig, data, c=-self.table["lnprob"], norm=LogNorm(), **scatter_kwargs) if self.hexbin: overplot_hexbin( fig, data[mask], C=-self.table["lnprob"][mask], norm=diskchef.engine.other.LogNormMaxOrders(maxdepth=1e4), reduce_C_function=np.nanmin, **hexbin_kwargs ) return fig
def fit(self, *args, **kwargs) ‑> Dict[str, Parameter]
-
Expand source code
def fit( self, *args, **kwargs ) -> Dict[str, Parameter]: raise CHEFNotImplementedError
def lnprob_fixed(self, *args, **kwargs)
-
Decorates self.lnprob so that minimum and maximum from self.parameters are considered
Expand source code
def lnprob_fixed(self, *args, **kwargs): """Decorates self.lnprob so that minimum and maximum from self.parameters are considered""" for parameter, arg in zip(self.parameters, *args): if (parameter.min is not None) and (parameter.min > arg): return -np.inf if (parameter.max is not None) and (parameter.max < arg): return -np.inf try: return self.lnprob(*args, **kwargs) except Exception as e: self.logger.error("lnprob crushed during the function call with %s %s", args, kwargs) self.logger.error("%s", e) return -np.inf
def save(self, filename: pathlib.Path = 'fitter.sav')
-
Saves the fitter to a file
Expand source code
def save(self, filename: Path = "fitter.sav"): """Saves the fitter to a file""" try: with open(filename, "wb") as fff: pickle.dump(self, fff) self.logger.info("Fitter successfully pickled to %s", filename) except Exception as e: self.logger.error("Could not save fitter! %s", e)
class Parameter (name: str, min: Union[astropy.units.quantity.Quantity, float] = None, max: Union[astropy.units.quantity.Quantity, float] = None, truth: float = None, format_: str = '{:.2f}', log: bool = False)
-
Class that handles parameters for diskchef Fitters
Args
name
- str - name of the parameter
min
- float
max
- float - prior minimal and maximal value of the parameter
truth
- optional, float - expected value of the parameter
format_
- str - python format-string for the parameter output
log
- bool - whether the logarithm of the valus should be used for fitting instead
Fields
fitted: float - fitted value fitted_error: float - 1-sigma error of the fitted value fitted_error_up: float - upper error of the fitted value fitted_error_down: float - lower error of the fitted value
Expand source code
@dataclass
class Parameter:
    """
    Class that handles parameters for diskchef Fitters

    Args:
        name: str - name of the parameter
        min: float
        max: float - prior minimal and maximal value of the parameter
        truth: optional, float - expected value of the parameter
        format_: str - python format-string for the parameter output
        log: bool - whether the logarithm of the value should be used for fitting instead

    Fields:
        fitted: float - fitted value
        fitted_error: float - 1-sigma error of the fitted value
        fitted_error_up: float - upper error of the fitted value
        fitted_error_down: float - lower error of the fitted value
    """
    name: str
    min: Union[u.Quantity, float] = None  # lower prior bound; None means unbounded below
    max: Union[u.Quantity, float] = None  # upper prior bound; None means unbounded above
    truth: float = None  # known/expected value, if any (appended in math_repr)
    format_: str = "{:.2f}"  # format string applied to every displayed value
    log: bool = False  # sample in log10 space when True (used by fitter transforms)

    def __post_init__(self) -> None:
        # Fit results are filled in by a Fitter after a run; start empty.
        self.fitted = None
        self.fitted_error = None  # symmetric 1-sigma error
        self.fitted_error_up = None  # asymmetric upper error
        self.fitted_error_down = None  # asymmetric lower error

    @property
    def math_repr(self) -> str:
        """
        Returns matplotlib/LaTeX-formatted representation of the parameter and its fitted value
        """
        out = "$"
        if self.fitted is None:
            # Not fitted yet: just the bare name.
            out += self.name
        elif self.fitted_error is None:
            # Fitted value without any error estimate.
            out += f"{self.name} = {self.format_.format(self.fitted)}"
        elif (self.fitted_error_down is None) or (self.fitted_error_up is None):
            # Symmetric error bar only.
            out += f"{self.name} = {self.format_.format(self.fitted)} ± {self.format_.format(self.fitted_error)}"
        else:
            # Asymmetric error bars rendered as LaTeX super/subscripts.
            out += f"{self.name} = {self.format_.format(self.fitted)}" \
                   f"^{{+{self.format_.format(self.fitted_error_up)}}}" \
                   f"_{{-{self.format_.format(self.fitted_error_down)}}}"
        if self.truth is not None:
            out += f" ({self.format_.format(self.truth)})"
        out += "$"
        return out

    def __str__(self) -> str:
        # Dollar-wrapped name so the string renders as math text in matplotlib.
        return f"${self.name}$"

    def __eq__(self, other):
        """
        Checks whether right parameter is within left parameter's error bar
        """
        # NOTE(review): defining __eq__ implicitly sets __hash__ to None, so
        # Parameter instances become unhashable — confirm this is intended.
        if self.fitted is None:
            return False
        elif self.fitted_error is None:
            # No error bar: fall back to exact equality with the number.
            return self.fitted == other
        elif (self.fitted_error_up is None) or (self.fitted_error_down is None):
            # Symmetric error bar.
            return self.fitted - self.fitted_error <= other <= self.fitted + self.fitted_error
        else:
            # Asymmetric error bar.
            return self.fitted - self.fitted_error_down <= other <= self.fitted + self.fitted_error_up
Class variables
var format_ : str
var log : bool
var max : Union[astropy.units.quantity.Quantity, float]
var min : Union[astropy.units.quantity.Quantity, float]
var name : str
var truth : float
Instance variables
var math_repr : str
-
Returns matplotlib/LaTeX-formatted representation of the parameter and its fitted value
Expand source code
@property
def math_repr(self) -> str:
    """
    Returns matplotlib/LaTeX-formatted representation of the parameter and its fitted value
    """
    fmt = self.format_.format
    if self.fitted is None:
        # Not fitted yet: just the bare name.
        body = self.name
    elif self.fitted_error is None:
        # Fitted value without any error estimate.
        body = f"{self.name} = {fmt(self.fitted)}"
    elif (self.fitted_error_down is None) or (self.fitted_error_up is None):
        # Symmetric error bar only.
        body = f"{self.name} = {fmt(self.fitted)} ± {fmt(self.fitted_error)}"
    else:
        # Asymmetric error bars rendered as LaTeX super/subscripts.
        body = (
            f"{self.name} = {fmt(self.fitted)}"
            f"^{{+{fmt(self.fitted_error_up)}}}"
            f"_{{-{fmt(self.fitted_error_down)}}}"
        )
    if self.truth is not None:
        body += f" ({fmt(self.truth)})"
    return f"${body}$"
class SciPyFitter (lnprob: Callable, parameters: List[Parameter], threads: int = None, progress: bool = False, hexbin: bool = True, fitter_kwargs: dict = <factory>, method: Optional[str] = None)
-
SciPyFitter(lnprob: Callable, parameters: List[diskchef.fitting.fitters.Parameter], threads: int = None, progress: bool = False, hexbin: bool = True, fitter_kwargs: dict =
, method: Optional[str] = None) Expand source code
@dataclass
class SciPyFitter(Fitter):
    """Fitter based on ``scipy.optimize.minimize``.

    NOTE(review): ``scipy.optimize.minimize`` *minimizes* the objective, so
    the ``lnprob`` callable given to this fitter should return the quantity to
    be minimized (e.g. a negative log-probability) — confirm sign convention
    with callers.
    """
    # Name of the scipy.optimize.minimize method; None lets scipy choose.
    method: Union[None, str] = None

    def fit(
            self,
            *args,
            **kwargs
    ):
        """Run ``scipy.optimize.minimize`` on ``self.lnprob``.

        ``args``/``kwargs`` are forwarded to ``scipy.optimize.minimize`` (the
        first positional argument should be the initial guess ``x0``).  The
        raw result is kept in ``self.scipy_result`` and the best-fit vector is
        copied into each ``parameter.fitted``.

        Returns:
            dict mapping parameter name to Parameter with ``fitted`` set.
        """
        # Bug fix: self.method was declared but never forwarded to scipy.
        # setdefault keeps backward compatibility (an explicit method= in
        # kwargs still wins; default None means scipy's automatic choice).
        kwargs.setdefault("method", self.method)
        scipy_result: scipy.optimize.OptimizeResult = scipy.optimize.minimize(
            self.lnprob,
            *args,
            **kwargs
        )
        self.scipy_result = scipy_result
        for parameter, result in zip(self.parameters, self.scipy_result.x):
            parameter.fitted = result
        return self.parameters_dict
Ancestors
Class variables
var method : Optional[str]
Methods
def fit(self, *args, **kwargs)
-
Expand source code
def fit(
        self,
        *args,
        **kwargs
):
    """Run ``scipy.optimize.minimize`` on ``self.lnprob``.

    ``args``/``kwargs`` are forwarded to ``scipy.optimize.minimize`` (the
    first positional argument should be the initial guess ``x0``).  The raw
    result is kept in ``self.scipy_result`` and the best-fit vector is copied
    into each ``parameter.fitted``.

    NOTE(review): scipy *minimizes* the objective — confirm the sign
    convention of ``lnprob`` with callers.

    Returns:
        dict mapping parameter name to Parameter with ``fitted`` set.
    """
    # Bug fix: self.method was declared on the class but never forwarded to
    # scipy.  setdefault keeps backward compatibility (an explicit method=
    # in kwargs still wins; default None means scipy's automatic choice).
    kwargs.setdefault("method", self.method)
    scipy_result: scipy.optimize.OptimizeResult = scipy.optimize.minimize(
        self.lnprob,
        *args,
        **kwargs
    )
    self.scipy_result = scipy_result
    for parameter, result in zip(self.parameters, self.scipy_result.x):
        parameter.fitted = result
    return self.parameters_dict
Inherited members
class UltraNestFitter (lnprob: Callable, parameters: List[Parameter], threads: int = None, progress: bool = False, hexbin: bool = True, fitter_kwargs: dict = <factory>, nwalkers: int = 100, nsteps: int = 100, transform: Callable = None, resume: Literal[True, 'resume', 'resume-similar', 'overwrite', 'subfolder'] = 'overwrite', log_dir: Union[str, pathlib.Path] = 'ultranest', run_kwargs: dict = <factory>, plot_corner: bool = True, storage_backend: Literal['hdf5', 'csv', 'tsv'] = 'hdf5')
-
UltraNestFitter(lnprob: Callable, parameters: List[diskchef.fitting.fitters.Parameter], threads: int = None, progress: bool = False, hexbin: bool = True, fitter_kwargs: dict =
, nwalkers: int = 100, nsteps: int = 100, transform: Callable = None, resume: Literal[True, 'resume', 'resume-similar', 'overwrite', 'subfolder'] = 'overwrite', log_dir: Union[str, pathlib.Path] = 'ultranest', run_kwargs: dict = , plot_corner: bool = True, storage_backend: Literal['hdf5', 'csv', 'tsv'] = 'hdf5') Expand source code
@dataclass
class UltraNestFitter(Fitter):
    """Fitter based on the UltraNest reactive nested sampler.

    The prior is defined by the parameters' ``min``/``max`` bounds (uniform,
    or log-uniform when ``parameter.log`` is True) via the default
    ``rescale`` transform.  ``run_kwargs`` are forwarded to
    ``ReactiveNestedSampler.run_iter`` after being topped up with
    ``DEFAULT_FOR_RUN_KWARGS``.
    """
    nwalkers: int = 100  # NOTE(review): not referenced inside this class — confirm use
    nsteps: int = 100  # NOTE(review): not referenced inside this class — confirm use
    transform: Callable = None  # prior transform; defaults to self.rescale
    resume: Literal[True, 'resume', 'resume-similar', 'overwrite', 'subfolder'] = 'overwrite'
    log_dir: Union[str, Path] = "ultranest"
    run_kwargs: dict = field(default_factory=dict)
    plot_corner: bool = True
    storage_backend: Literal['hdf5', 'csv', 'tsv'] = 'hdf5'

    # Defaults merged into run_kwargs by __post_init__ without overriding
    # user-supplied values.
    DEFAULT_FOR_RUN_KWARGS = dict(
        Lepsilon=0.01,  # Increase when lnprob is inaccurate
        frac_remain=0.05,  # Decrease if lnprob is expected to have peaks
        min_num_live_points=100,
        dlogz=1.,
        dKL=1.,
    )
    # Finite stand-in for infinity: ultranest needs finite log-likelihoods.
    INFINITY = 1e50

    def __post_init__(self):
        super().__post_init__()
        if self.transform is None:
            # Default prior transform: uniform / log-uniform within bounds.
            self.transform = self.rescale
        self.sampler = None
        for key, value in self.DEFAULT_FOR_RUN_KWARGS.items():
            if key not in self.run_kwargs:
                self.run_kwargs[key] = value

    def rescale(self, cube):
        """Map a unit-hypercube point onto the parameter ranges."""
        params = np.empty_like(cube)
        for i, parameter in enumerate(self.parameters):
            if parameter.log:
                # Log-uniform prior: interpolate in log10 space.
                params[i] = 10 ** (cube[i] * (np.log10(parameter.max) - np.log10(parameter.min)) + np.log10(parameter.min))
            else:
                params[i] = cube[i] * (parameter.max - parameter.min) + parameter.min
        return params

    def lnprob_fixed(self, *args, **kwargs):
        """Like ``Fitter.lnprob_fixed``, but clipped to finite values for ultranest."""
        return np.nan_to_num(
            super().lnprob_fixed(*args, **kwargs),
            neginf=-self.INFINITY, posinf=self.INFINITY, nan=-self.INFINITY
        )

    def save(self, filename: Path = "fitter.sav"):
        """Saves the fitter to a file. Does not save ultranest.ReactiveNestedSampler object!"""
        # The sampler is not picklable: detach it before dumping and restore
        # it in finally so the fitter keeps working afterwards.
        sampler = self.sampler
        del self.sampler
        try:
            with open(filename, "wb") as fff:
                pickle.dump(self, fff)
            # Typo fix: "picked" -> "pickled".
            self.logger.info("Fitter successfully pickled to %s. Sampler was not pickled!", filename)
        except Exception as e:
            # Best effort: saving must never abort a running fit.
            self.logger.error("Could not save fitter! %s", e)
        finally:
            self.sampler = sampler

    def fit(
            self,
            *args,
            **kwargs
    ):
        """Run nested sampling, updating parameters after each iteration.

        Extra ``args``/``kwargs`` are bound into ``lnprob_fixed`` via
        ``partial``.  On each ``run_iter`` step (MPI rank 0 only) the weighted
        samples are stored in ``self._table``, the posterior summary is copied
        into the parameters, diagnostic plots are produced and the fitter is
        pickled.  Returns ``self.parameters_dict``.
        """
        self.log_dir = Path(self.log_dir)
        lnprob = partial(self.lnprob_fixed, *args, **kwargs)
        # ultranest uses the callable's __name__ in its output.
        lnprob.__name__ = self.lnprob.__name__
        self.sampler = ultranest.ReactiveNestedSampler(
            [str(param) for param in self.parameters],
            lnprob,
            self.transform,
            log_dir=self.log_dir,
            resume=self.resume,
            storage_backend=self.storage_backend,
            **self.fitter_kwargs
        )
        for i, result in enumerate(self.sampler.run_iter(**self.run_kwargs)):
            if not self.sampler.use_mpi or self.sampler.mpi_rank == 0:
                self.logger.info("Step %d: %s", i, result)
                tbl = QTable(self.sampler.results['weighted_samples']['points'],
                             names=[par.name for par in self.parameters])
                tbl["lnprob"] = self.sampler.results['weighted_samples']['logl']
                # Undo the finite clipping applied in lnprob_fixed.
                tbl["lnprob"][tbl["lnprob"] <= -self.INFINITY] = -np.inf
                tbl["weight"] = self.sampler.results['weighted_samples']['weights']
                self._table = tbl
                results = self.sampler.results['posterior']
                for iparam, parameter in enumerate(self.parameters):
                    parameter.fitted = results['mean'][iparam]
                    parameter.fitted_error_up = results['errup'][iparam] - results['mean'][iparam]
                    parameter.fitted_error_down = results['mean'][iparam] - results['errlo'][iparam]
                    parameter.fitted_error = results['stdev'][iparam]
                try:
                    self.sampler.plot()
                except Exception as e:
                    self.logger.error("Could not plot! %s", e)
                if self.plot_corner:
                    try:
                        fig = self.corner()
                        fig.savefig(self.log_dir / f"corner_{i:06d}.pdf")
                        fig.savefig(self.log_dir / "corner.pdf")
                    except ValueError as e:
                        self.logger.error("Could not make corner plot for %d:", i)
                        self.logger.error(e)
                self.save(self.log_dir / f"fitter_{i:06d}.sav")
        if not self.sampler.use_mpi or self.sampler.mpi_rank == 0:
            self._post_fit()
        return self.parameters_dict
Ancestors
Class variables
var DEFAULT_FOR_RUN_KWARGS
var INFINITY
var log_dir : Union[str, pathlib.Path]
var nsteps : int
var nwalkers : int
var plot_corner : bool
var resume : Literal[True, 'resume', 'resume-similar', 'overwrite', 'subfolder']
var run_kwargs : dict
var storage_backend : Literal['hdf5', 'csv', 'tsv']
var transform : Callable
Methods
def fit(self, *args, **kwargs)
-
Expand source code
def fit(
        self,
        *args,
        **kwargs
):
    """Run nested sampling, updating parameters after each iteration.

    Extra ``args``/``kwargs`` are bound into ``lnprob_fixed`` via ``partial``.
    Returns ``self.parameters_dict`` with fitted values and errors filled in.
    """
    self.log_dir = Path(self.log_dir)
    lnprob = partial(self.lnprob_fixed, *args, **kwargs)
    # ultranest uses the callable's __name__ in its output.
    lnprob.__name__ = self.lnprob.__name__
    self.sampler = ultranest.ReactiveNestedSampler(
        [str(param) for param in self.parameters],
        lnprob,
        self.transform,
        log_dir=self.log_dir,
        resume=self.resume,
        storage_backend=self.storage_backend,
        **self.fitter_kwargs
    )
    for i, result in enumerate(self.sampler.run_iter(**self.run_kwargs)):
        # Only the MPI master (or a non-MPI run) processes results.
        if not self.sampler.use_mpi or self.sampler.mpi_rank == 0:
            self.logger.info("Step %d: %s", i, result)
            tbl = QTable(self.sampler.results['weighted_samples']['points'],
                         names=[par.name for par in self.parameters])
            tbl["lnprob"] = self.sampler.results['weighted_samples']['logl']
            # Undo the finite clipping applied in lnprob_fixed.
            tbl["lnprob"][tbl["lnprob"] <= -self.INFINITY] = -np.inf
            tbl["weight"] = self.sampler.results['weighted_samples']['weights']
            self._table = tbl
            results = self.sampler.results['posterior']
            # Copy posterior summary statistics into the Parameter objects.
            for iparam, parameter in enumerate(self.parameters):
                parameter.fitted = results['mean'][iparam]
                parameter.fitted_error_up = results['errup'][iparam] - results['mean'][iparam]
                parameter.fitted_error_down = results['mean'][iparam] - results['errlo'][iparam]
                parameter.fitted_error = results['stdev'][iparam]
            try:
                self.sampler.plot()
            except Exception as e:
                self.logger.error("Could not plot! %s", e)
            if self.plot_corner:
                try:
                    fig = self.corner()
                    fig.savefig(self.log_dir / f"corner_{i:06d}.pdf")
                    fig.savefig(self.log_dir / "corner.pdf")
                except ValueError as e:
                    self.logger.error("Could not make corner plot for %d:", i)
                    self.logger.error(e)
            # Checkpoint the fitter itself after every iteration.
            self.save(self.log_dir / f"fitter_{i:06d}.sav")
    if not self.sampler.use_mpi or self.sampler.mpi_rank == 0:
        self._post_fit()
    return self.parameters_dict
def rescale(self, cube)
-
Expand source code
def rescale(self, cube):
    """Map a unit-hypercube point onto the parameter ranges.

    Linear interpolation between each parameter's ``min`` and ``max``;
    log-uniform (interpolation in log10 space) when ``parameter.log`` is True.
    """
    scaled = np.empty_like(cube)
    for index, param in enumerate(self.parameters):
        unit_value = cube[index]
        if param.log:
            low = np.log10(param.min)
            high = np.log10(param.max)
            scaled[index] = 10 ** (unit_value * (high - low) + low)
        else:
            scaled[index] = unit_value * (param.max - param.min) + param.min
    return scaled
def save(self, filename: pathlib.Path = 'fitter.sav')
-
Saves the fitter to a file. Does not save ultranest.ReactiveNestedSampler object!
Expand source code
def save(self, filename: Path = "fitter.sav"):
    """Saves the fitter to a file. Does not save ultranest.ReactiveNestedSampler object!

    The sampler is detached before pickling (it is not picklable) and
    re-attached in ``finally`` so the fitter keeps working afterwards.
    """
    sampler = self.sampler
    del self.sampler
    try:
        with open(filename, "wb") as fff:
            pickle.dump(self, fff)
        # Typo fix: "picked" -> "pickled".
        self.logger.info("Fitter successfully pickled to %s. Sampler was not pickled!", filename)
    except Exception as e:
        # Best effort: saving must never abort a running fit.
        self.logger.error("Could not save fitter! %s", e)
    finally:
        self.sampler = sampler
Inherited members