Module diskchef.engine.plot
Module with plotting helper routines for diskchef
Expand source code
"""Module with plotting helper routines for diskchef"""
import copy
import logging
from dataclasses import dataclass, field
import matplotlib.ticker
import numpy as np
from typing import Literal, Union, List
from astropy import units as u
from astropy.visualization import quantity_support
quantity_support()
import matplotlib.axes
import matplotlib.scale
import matplotlib.colors
from matplotlib import pyplot as plt
from matplotlib.ticker import LogFormatterMathtext
from diskchef.engine.ctable import CTable
from diskchef.engine.other import LogNormMaxOrders
from diskchef.engine.exceptions import CHEFValueError
from chemical_names import from_string
@dataclass
class Plot:
table: CTable
axes: matplotlib.axes.Axes = None
xscale: Union[Literal["linear", "log", "symlog", "logit"], matplotlib.scale.ScaleBase] = "log"
yscale: Union[Literal["linear", "log", "symlog", "logit"], matplotlib.scale.ScaleBase] = "linear"
margins: float = 0.
unit_format: Literal["latex", "cds", None] = "latex"
maxdepth: float = 1e6
def __post_init__(self):
self.logger = logging.getLogger(__name__ + '.' + self.__class__.__qualname__)
self.logger.info("Creating an instance of %s", self.__class__.__qualname__)
self.logger.debug("With parameters: %s", self.__dict__)
if self.axes is None:
self.axes = plt.axes()
def normalize_axes(self):
self.axes.set_xscale(self.xscale)
self.axes.set_yscale(self.yscale)
self.axes.margins(self.margins)
def formatted(self, unit: u.Unit):
if unit == u.dimensionless_unscaled:
return "[--]"
else:
return fr"[{unit.to_string(self.unit_format)}]"
@dataclass
class Plot2D(Plot):
"""2D visualization of a disk"""
data1: str = None
data2: str = None
x_axis: str = "Radius"
y_axis: str = "Height to radius"
norm: matplotlib.colors.Normalize = None
colorbar: bool = True
labels: bool = True
cmap: Union[matplotlib.colors.Colormap, str] = None
multiply_by: Union[str, float] = 1.
levels: u.Quantity = None
desired_max: u.Quantity = None
norm_lower: bool = False
def __post_init__(self):
super().__post_init__()
if self.norm is None:
self.norm = LogNormMaxOrders(vmax=self.desired_max, maxdepth=self.maxdepth)
try:
self.multiply_by = self.table[self.multiply_by]
except (KeyError, ValueError):
pass
self.table.check_zeros(self.data1)
self.table.check_zeros(self.data2)
data1_q = u.Quantity(self.table[self.data1] * self.multiply_by)
data1 = data1_q.value
if self.norm_lower:
data2_q = u.Quantity(self.table[self.data2] * self.multiply_by)
self.data_unit = data2_q.unit
data2 = data2_q.value
self.norm(data2)
else:
self.data_unit = data1_q.unit
self.norm(data1)
x_axis = self.table[self.x_axis].value
y_axis = self.table[self.y_axis].value
self.x_unit = self.table[self.x_axis].unit
self.y_unit = self.table[self.y_axis].unit
self.axes.set_xlabel(f"{self.x_axis} {self.formatted(self.x_unit)}")
self.axes.set_ylabel(f"{self.y_axis} {self.formatted(self.y_unit)}")
if self.levels is None:
minlevel = np.round(np.log10(self.norm.vmin))
maxlevel = np.round(np.log10(self.norm.vmax))
if maxlevel - minlevel == 1:
self.cbar_formatter = LogFormatterMathtext()
minlevel = np.round(np.log10(self.norm.vmin), 1)
maxlevel = np.round(np.log10(self.norm.vmax), 1)
else:
self.cbar_formatter = LogFormatterMathtext()
# as we need to convert self.levels to correct units and leave dimensionless
# noinspection PyTypeChecker
self.levels = np.logspace(minlevel, maxlevel, 13)
else:
self.levels = self.levels.to_value(self.data_unit)
if len(set(self.levels)) == 1:
self.levels = self.levels[0] * np.array([0.5, 1., 2.])
self.normalize_axes()
im = self.axes.tricontourf(
x_axis, y_axis,
data1,
levels=self.levels,
norm=self.norm,
extend="both",
cmap=self.cmap,
)
if self.data2 is not None:
data2 = u.Quantity(self.table[self.data2] * self.multiply_by).to_value(self.data_unit)
self.axes.tricontourf(
x_axis, -y_axis,
data2,
levels=self.levels,
norm=self.norm,
extend="both",
cmap=self.cmap,
)
self.axes.axhline(0, color="black")
try:
formatter = matplotlib.ticker.FuncFormatter(lambda x, pos: f'{abs(x):.2f}') # todo as function def
self.axes.yaxis.set_major_formatter(formatter)
except ValueError:
self.logger.warning("Could not fix yticks for negatives: %s", self.axes.get_yticklabels())
self.axes.margins(self.margins)
if self.colorbar:
im.set_clim(self.norm.vmin, self.norm.vmax)
self.cbar = self.axes.figure.colorbar(
im, ax=self.axes,
format=self.cbar_formatter
)
self.cbar.set_label(self.formatted(data1_q.unit), rotation="horizontal")
self.cbar.ax.minorticks_off()
if self.labels:
txt = self.axes.text(
0.05, 0.05, from_string(self.data2),
transform=self.axes.transAxes,
verticalalignment='bottom',
bbox=dict(
boxstyle="round",
ec=(1., 1., 1., 0.5),
fc=(1., 1., 1., 0.7),
)
)
if self.data2 is not None:
txt = self.axes.text(
0.05, 0.95, from_string(self.data1),
transform=self.axes.transAxes,
verticalalignment='top',
bbox=dict(
boxstyle="round",
ec=(1., 1., 1., 0.5),
fc=(1., 1., 1., 0.7),
)
)
def contours(
self,
data: str,
levels: Union[u.Quantity, List[float]],
x_axis: str = "Radius",
y_axis: str = "Height to radius",
clabel_kwargs: dict = None,
colors: Union[str, List[str]] = "black",
on_colorbar: bool = True,
location: Literal["upper", "bottom", "both"] = "both",
**kwargs
):
if clabel_kwargs is None:
clabel_kwargs = {}
data_q = u.Quantity(self.table[data])
data = data_q.value
dataunit = data_q.unit
if "fmt" not in clabel_kwargs.keys():
clabel_kwargs["fmt"] = f"%d {dataunit.to_string(self.unit_format)}"
x_axis = self.table[x_axis].to_value(self.x_unit)
y_axis = self.table[y_axis].to_value(self.y_unit)
if location == "both":
x_axis = [*x_axis, *x_axis]
y_axis = [*-y_axis, *y_axis]
data = [*data, *data]
elif location == "bottom":
y_axis = -y_axis
elif location == "upper":
pass
else:
raise CHEFValueError('location bust be one of ["upper", "bottom", "both"]')
conts = self.axes.tricontour(
x_axis, y_axis,
data,
levels=levels.to_value(dataunit),
colors=colors,
**kwargs
)
if on_colorbar:
try:
levels_as_data = levels.to_value(self.data_unit)
new_conts = copy.copy(conts)
new_conts.levels = levels_as_data
self.cbar.add_lines(new_conts)
except u.core.UnitConversionError as e:
self.logger.info(e)
try:
conts.clabel(levels.to_value(dataunit), use_clabeltext=True, inline=True, inline_spacing=1, **clabel_kwargs)
except ValueError as e:
self.logger.warning(e)
@dataclass
class Plot1D(Plot):
data: List[str] = None
x_axis: u.au = None
yscale: Union[Literal["linear", "log", "symlog", "logit"], matplotlib.scale.ScaleBase] = "log"
labels: bool = True
cmap: Union[matplotlib.colors.Colormap, str] = None
plot_kwargs: dict = field(default_factory=dict)
def __post_init__(self):
super().__post_init__()
if not self.data:
raise ValueError("List of `data` arguments must be specified")
if self.x_axis is None:
if self.table.is_in_zr_regular_grid:
self.x_axis = u.Quantity(sorted(set(self.table.r)))
else:
self.x_axis = np.geomspace(np.min(self.table.r), np.max(self.table), 100)
self.x_unit = self.x_axis.unit
self.y_unit = (self.table[self.data[0]][0] * self.x_axis[0]).cgs.unit
self.normalize_axes()
self.axes.set_xlabel(f"Radius {self.formatted(self.x_unit)}")
self.axes.set_ylabel(f"{self.formatted(self.y_unit)}")
for colname in self.data:
data_to_plot = self.table.column_density(colname, self.x_axis).cgs
self.axes.plot(self.x_axis, data_to_plot, label=from_string(colname), **self.plot_kwargs)
self.axes.legend()
Classes
class Plot (table: CTable, axes: matplotlib.axes._axes.Axes = None, xscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'log', yscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'linear', margins: float = 0.0, unit_format: Literal['latex', 'cds', None] = 'latex', maxdepth: float = 1000000.0)
-
Plot(table: diskchef.engine.ctable.CTable, axes: matplotlib.axes._axes.Axes = None, xscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'log', yscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'linear', margins: float = 0.0, unit_format: Literal['latex', 'cds', None] = 'latex', maxdepth: float = 1000000.0)
Expand source code
@dataclass class Plot: table: CTable axes: matplotlib.axes.Axes = None xscale: Union[Literal["linear", "log", "symlog", "logit"], matplotlib.scale.ScaleBase] = "log" yscale: Union[Literal["linear", "log", "symlog", "logit"], matplotlib.scale.ScaleBase] = "linear" margins: float = 0. unit_format: Literal["latex", "cds", None] = "latex" maxdepth: float = 1e6 def __post_init__(self): self.logger = logging.getLogger(__name__ + '.' + self.__class__.__qualname__) self.logger.info("Creating an instance of %s", self.__class__.__qualname__) self.logger.debug("With parameters: %s", self.__dict__) if self.axes is None: self.axes = plt.axes() def normalize_axes(self): self.axes.set_xscale(self.xscale) self.axes.set_yscale(self.yscale) self.axes.margins(self.margins) def formatted(self, unit: u.Unit): if unit == u.dimensionless_unscaled: return "[--]" else: return fr"[{unit.to_string(self.unit_format)}]"
Subclasses
Class variables
var axes : matplotlib.axes._axes.Axes
var margins : float
var maxdepth : float
var table : CTable
var unit_format : Literal['latex', 'cds', None]
var xscale : Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase]
var yscale : Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase]
Methods
def formatted(self, unit: astropy.units.core.Unit)
-
Expand source code
def formatted(self, unit: u.Unit): if unit == u.dimensionless_unscaled: return "[--]" else: return fr"[{unit.to_string(self.unit_format)}]"
def normalize_axes(self)
-
Expand source code
def normalize_axes(self): self.axes.set_xscale(self.xscale) self.axes.set_yscale(self.yscale) self.axes.margins(self.margins)
class Plot1D (table: CTable, axes: matplotlib.axes._axes.Axes = None, xscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'log', yscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'log', margins: float = 0.0, unit_format: Literal['latex', 'cds', None] = 'latex', maxdepth: float = 1000000.0, data: List[str] = None, x_axis: Unit("AU") = None, labels: bool = True, cmap: Union[matplotlib.colors.Colormap, str] = None, plot_kwargs: dict = <factory>)
-
Plot1D(table: diskchef.engine.ctable.CTable, axes: matplotlib.axes._axes.Axes = None, xscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'log', yscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'log', margins: float = 0.0, unit_format: Literal['latex', 'cds', None] = 'latex', maxdepth: float = 1000000.0, data: List[str] = None, x_axis: Unit("AU") = None, labels: bool = True, cmap: Union[matplotlib.colors.Colormap, str] = None, plot_kwargs: dict =
) Expand source code
@dataclass class Plot1D(Plot): data: List[str] = None x_axis: u.au = None yscale: Union[Literal["linear", "log", "symlog", "logit"], matplotlib.scale.ScaleBase] = "log" labels: bool = True cmap: Union[matplotlib.colors.Colormap, str] = None plot_kwargs: dict = field(default_factory=dict) def __post_init__(self): super().__post_init__() if not self.data: raise ValueError("List of `data` arguments must be specified") if self.x_axis is None: if self.table.is_in_zr_regular_grid: self.x_axis = u.Quantity(sorted(set(self.table.r))) else: self.x_axis = np.geomspace(np.min(self.table.r), np.max(self.table), 100) self.x_unit = self.x_axis.unit self.y_unit = (self.table[self.data[0]][0] * self.x_axis[0]).cgs.unit self.normalize_axes() self.axes.set_xlabel(f"Radius {self.formatted(self.x_unit)}") self.axes.set_ylabel(f"{self.formatted(self.y_unit)}") for colname in self.data: data_to_plot = self.table.column_density(colname, self.x_axis).cgs self.axes.plot(self.x_axis, data_to_plot, label=from_string(colname), **self.plot_kwargs) self.axes.legend()
Ancestors
Class variables
var cmap : Union[matplotlib.colors.Colormap, str]
var data : List[str]
var labels : bool
var plot_kwargs : dict
var x_axis : Unit("AU")
var yscale : Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase]
class Plot2D (table: CTable, axes: matplotlib.axes._axes.Axes = None, xscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'log', yscale: Union[Literal['linear', 'log', 'symlog', 'logit'], matplotlib.scale.ScaleBase] = 'linear', margins: float = 0.0, unit_format: Literal['latex', 'cds', None] = 'latex', maxdepth: float = 1000000.0, data1: str = None, data2: str = None, x_axis: str = 'Radius', y_axis: str = 'Height to radius', norm: matplotlib.colors.Normalize = None, colorbar: bool = True, labels: bool = True, cmap: Union[matplotlib.colors.Colormap, str] = None, multiply_by: Union[str, float] = 1.0, levels: astropy.units.quantity.Quantity = None, desired_max: astropy.units.quantity.Quantity = None, norm_lower: bool = False)
-
2D visualization of a disk
Expand source code
@dataclass class Plot2D(Plot): """2D visualization of a disk""" data1: str = None data2: str = None x_axis: str = "Radius" y_axis: str = "Height to radius" norm: matplotlib.colors.Normalize = None colorbar: bool = True labels: bool = True cmap: Union[matplotlib.colors.Colormap, str] = None multiply_by: Union[str, float] = 1. levels: u.Quantity = None desired_max: u.Quantity = None norm_lower: bool = False def __post_init__(self): super().__post_init__() if self.norm is None: self.norm = LogNormMaxOrders(vmax=self.desired_max, maxdepth=self.maxdepth) try: self.multiply_by = self.table[self.multiply_by] except (KeyError, ValueError): pass self.table.check_zeros(self.data1) self.table.check_zeros(self.data2) data1_q = u.Quantity(self.table[self.data1] * self.multiply_by) data1 = data1_q.value if self.norm_lower: data2_q = u.Quantity(self.table[self.data2] * self.multiply_by) self.data_unit = data2_q.unit data2 = data2_q.value self.norm(data2) else: self.data_unit = data1_q.unit self.norm(data1) x_axis = self.table[self.x_axis].value y_axis = self.table[self.y_axis].value self.x_unit = self.table[self.x_axis].unit self.y_unit = self.table[self.y_axis].unit self.axes.set_xlabel(f"{self.x_axis} {self.formatted(self.x_unit)}") self.axes.set_ylabel(f"{self.y_axis} {self.formatted(self.y_unit)}") if self.levels is None: minlevel = np.round(np.log10(self.norm.vmin)) maxlevel = np.round(np.log10(self.norm.vmax)) if maxlevel - minlevel == 1: self.cbar_formatter = LogFormatterMathtext() minlevel = np.round(np.log10(self.norm.vmin), 1) maxlevel = np.round(np.log10(self.norm.vmax), 1) else: self.cbar_formatter = LogFormatterMathtext() # as we need to convert self.levels to correct units and leave dimensionless # noinspection PyTypeChecker self.levels = np.logspace(minlevel, maxlevel, 13) else: self.levels = self.levels.to_value(self.data_unit) if len(set(self.levels)) == 1: self.levels = self.levels[0] * np.array([0.5, 1., 2.]) self.normalize_axes() im = self.axes.tricontourf( x_axis, y_axis, data1, levels=self.levels, norm=self.norm, extend="both", cmap=self.cmap, ) if self.data2 is not None: data2 = u.Quantity(self.table[self.data2] * self.multiply_by).to_value(self.data_unit) self.axes.tricontourf( x_axis, -y_axis, data2, levels=self.levels, norm=self.norm, extend="both", cmap=self.cmap, ) self.axes.axhline(0, color="black") try: formatter = matplotlib.ticker.FuncFormatter(lambda x, pos: f'{abs(x):.2f}') # todo as function def self.axes.yaxis.set_major_formatter(formatter) except ValueError: self.logger.warning("Could not fix yticks for negatives: %s", self.axes.get_yticklabels()) self.axes.margins(self.margins) if self.colorbar: im.set_clim(self.norm.vmin, self.norm.vmax) self.cbar = self.axes.figure.colorbar( im, ax=self.axes, format=self.cbar_formatter ) self.cbar.set_label(self.formatted(data1_q.unit), rotation="horizontal") self.cbar.ax.minorticks_off() if self.labels: txt = self.axes.text( 0.05, 0.05, from_string(self.data2), transform=self.axes.transAxes, verticalalignment='bottom', bbox=dict( boxstyle="round", ec=(1., 1., 1., 0.5), fc=(1., 1., 1., 0.7), ) ) if self.data2 is not None: txt = self.axes.text( 0.05, 0.95, from_string(self.data1), transform=self.axes.transAxes, verticalalignment='top', bbox=dict( boxstyle="round", ec=(1., 1., 1., 0.5), fc=(1., 1., 1., 0.7), ) ) def contours( self, data: str, levels: Union[u.Quantity, List[float]], x_axis: str = "Radius", y_axis: str = "Height to radius", clabel_kwargs: dict = None, colors: Union[str, List[str]] = "black", on_colorbar: bool = True, location: Literal["upper", "bottom", "both"] = "both", **kwargs ): if clabel_kwargs is None: clabel_kwargs = {} data_q = u.Quantity(self.table[data]) data = data_q.value dataunit = data_q.unit if "fmt" not in clabel_kwargs.keys(): clabel_kwargs["fmt"] = f"%d {dataunit.to_string(self.unit_format)}" x_axis = self.table[x_axis].to_value(self.x_unit) y_axis = self.table[y_axis].to_value(self.y_unit) if location == "both": x_axis = [*x_axis, *x_axis] y_axis = [*-y_axis, *y_axis] data = [*data, *data] elif location == "bottom": y_axis = -y_axis elif location == "upper": pass else: raise CHEFValueError('location bust be one of ["upper", "bottom", "both"]') conts = self.axes.tricontour( x_axis, y_axis, data, levels=levels.to_value(dataunit), colors=colors, **kwargs ) if on_colorbar: try: levels_as_data = levels.to_value(self.data_unit) new_conts = copy.copy(conts) new_conts.levels = levels_as_data self.cbar.add_lines(new_conts) except u.core.UnitConversionError as e: self.logger.info(e) try: conts.clabel(levels.to_value(dataunit), use_clabeltext=True, inline=True, inline_spacing=1, **clabel_kwargs) except ValueError as e: self.logger.warning(e)
Ancestors
Class variables
var cmap : Union[matplotlib.colors.Colormap, str]
var colorbar : bool
var data1 : str
var data2 : str
var desired_max : astropy.units.quantity.Quantity
var labels : bool
var levels : astropy.units.quantity.Quantity
var multiply_by : Union[str, float]
var norm : matplotlib.colors.Normalize
var norm_lower : bool
var x_axis : str
var y_axis : str
Methods
def contours(self, data: str, levels: Union[astropy.units.quantity.Quantity, List[float]], x_axis: str = 'Radius', y_axis: str = 'Height to radius', clabel_kwargs: dict = None, colors: Union[str, List[str]] = 'black', on_colorbar: bool = True, location: Literal['upper', 'bottom', 'both'] = 'both', **kwargs)
-
Expand source code
def contours( self, data: str, levels: Union[u.Quantity, List[float]], x_axis: str = "Radius", y_axis: str = "Height to radius", clabel_kwargs: dict = None, colors: Union[str, List[str]] = "black", on_colorbar: bool = True, location: Literal["upper", "bottom", "both"] = "both", **kwargs ): if clabel_kwargs is None: clabel_kwargs = {} data_q = u.Quantity(self.table[data]) data = data_q.value dataunit = data_q.unit if "fmt" not in clabel_kwargs.keys(): clabel_kwargs["fmt"] = f"%d {dataunit.to_string(self.unit_format)}" x_axis = self.table[x_axis].to_value(self.x_unit) y_axis = self.table[y_axis].to_value(self.y_unit) if location == "both": x_axis = [*x_axis, *x_axis] y_axis = [*-y_axis, *y_axis] data = [*data, *data] elif location == "bottom": y_axis = -y_axis elif location == "upper": pass else: raise CHEFValueError('location bust be one of ["upper", "bottom", "both"]') conts = self.axes.tricontour( x_axis, y_axis, data, levels=levels.to_value(dataunit), colors=colors, **kwargs ) if on_colorbar: try: levels_as_data = levels.to_value(self.data_unit) new_conts = copy.copy(conts) new_conts.levels = levels_as_data self.cbar.add_lines(new_conts) except u.core.UnitConversionError as e: self.logger.info(e) try: conts.clabel(levels.to_value(dataunit), use_clabeltext=True, inline=True, inline_spacing=1, **clabel_kwargs) except ValueError as e: self.logger.warning(e)