Module alchemic_chemgraph.snapshot

Expand source code
import math
import re
import random

from matplotlib import pyplot as plt
from matplotlib import cm, colors
import numpy as np
import networkx
import plotly.graph_objects as go

from alchemic_chemgraph.reaction import Reaction

# Seed Python's and NumPy's global random number generators at import time so
# that any randomness later drawn from them is reproducible between runs.
# NOTE(review): this resets global RNG state for every importer of this
# module -- confirm that side effect is intended.
random.seed(246)
np.random.seed(4812)


class Snapshot(object):
    """Class to handle an ALCHEMIC reaction graph.

    The text *entry* is parsed as follows.  Its second line is the title,
    which must contain ``at <time> [years]``; the first token of the title
    (after stripping leading ``#``) names the main species.  Every line from
    the fourth onward is parsed by ``Reaction``.

    The reactions are turned into a ``networkx.DiGraph`` with species and
    reaction nodes (reactant -> reaction -> product).  The graph can be drawn
    with matplotlib (``show``) or plotly (``plotly``).
    """

    def __init__(self, entry: str, layout=None):
        """Parse *entry* and build the graph, node positions and colormap.

        Args:
            entry: Text block describing one snapshot (see class docstring).
            layout: Passed to ``graph_positions`` (``None``/``'shell'`` or
                ``'kamada_kawai'``).

        Raises:
            IndexError: if *entry* has fewer than two lines or the title lacks
                ``at <time> [years]``.
            ValueError: if the captured time is not a float, or *layout* is
                unknown.
        """
        lines = entry.strip().splitlines()
        self.title = lines[1]
        self.time = float(re.findall(r'at\s*(\S*)\s*\[years\]', self.title)[0])
        self.mainspec = self.title.lstrip("#").split(maxsplit=1)[0]
        self.reactions = []
        self.reactionsdict = {}
        # Lines 0 and 2 are skipped; reaction lines start at index 3.
        for line in lines[3:]:
            reaction = Reaction(line)
            self.reactions.append(reaction)
            self.reactionsdict[reaction.reaction_id] = reaction
        self._write_all_species()
        self._write_graph()
        self.graph_positions(layout=layout)
        self._reaction_colormap()

    def __repr__(self):
        return self.title

    def _write_all_species_set(self):
        """Collect the unique species of all reactions into the set ``self.species``.

        Unlike ``_write_all_species`` this does not preserve first-seen order.
        """
        # Must be a set: ``{}`` is an empty dict, whose ``update`` would try to
        # unpack each species name into a (key, value) pair and fail.
        self.species = set()
        for reaction in self.reactions:
            self.species.update(reaction.reactants, reaction.products)

    def _write_all_species(self):
        """Store the unique species of all reactions in ``self.species``.

        The result is a list in first-seen order: for each reaction, its
        reactants and then its products.
        """
        # dict.fromkeys de-duplicates while keeping insertion order.
        self.species = list(dict.fromkeys(
            species
            for reaction in self.reactions
            for species in (*reaction.reactants, *reaction.products)
        ))

    def _write_graph(self):
        """Build ``self.graph``, a DiGraph of species and reaction nodes.

        Species nodes get ``color='#aaaaff', type='species'``.  Each reaction
        becomes a node (``color='#ffff00', type='reaction'``) linked by edges
        reactant -> reaction -> product.  The reaction node and all of its
        edges carry ``adiff``, ``rdiff``, ``rdiff_nonabs`` and ``weight``.
        """
        self.graph = networkx.DiGraph()
        self.graph.add_nodes_from(
            self.species,
            color='#aaaaff',
            type='species'
        )

        for reaction in self.reactions:
            kw = {
                'adiff': reaction.adiff, 'rdiff': reaction.rdiff,
                'rdiff_nonabs': reaction.rdiff_nonabs,
                # Larger |rdiff| gives a smaller weight.  NOTE(review): an
                # rdiff of 0 yields an infinite weight.
                'weight': 1 / np.sqrt(np.abs(reaction.rdiff))
            }
            self.graph.add_node(
                reaction.reaction_id,
                color='#ffff00',
                type='reaction',
                **kw,
            )
            for reactant in reaction.reactants:
                self.graph.add_edge(reactant, reaction.reaction_id, **kw)
            for product in reaction.products:
                self.graph.add_edge(reaction.reaction_id, product, **kw)

    def graph_positions(self, layout=None, all_reactions=None, all_species=None):
        """Compute node positions, store them in ``self.pos`` and return them.

        Args:
            layout: ``None`` or ``'shell'`` gives three concentric shells:
                the main species in the centre, reactions in the middle ring
                and the other species in the outer ring.  ``'kamada_kawai'``
                uses ``networkx.kamada_kawai_layout``.
            all_reactions: Node ids for the middle shell.  Defaults to the ids
                of ``self.reactions``.
            all_species: Species for the outer shell.  Defaults to
                ``self.species``.  The main species is always left out of it.

        Returns:
            dict mapping node -> position, in networkx ``pos`` format.

        Raises:
            ValueError: for an unknown *layout*.
        """
        if layout is None or layout == 'shell':
            if all_species is None:
                all_species = self.species
            if all_reactions is None:
                all_reactions = [reaction.reaction_id for reaction in self.reactions]
            # Filter instead of list.remove(): this works for any iterable and
            # does not fail when the main species is absent from all_species.
            outer_species = [s for s in all_species if s != self.mainspec]
            pos = networkx.shell_layout(
                self.graph,
                [
                    [self.mainspec],
                    list(all_reactions),
                    outer_species,
                ]
            )
        elif layout == 'kamada_kawai':
            pos = networkx.kamada_kawai_layout(self.graph)
        else:
            raise ValueError(
                f"layout must be None, 'shell' or 'kamada_kawai', got {layout!r}"
            )
        self.pos = pos
        return pos

    def _reaction_colormap(self, amplitude=0.3, cmap='RdBu_r'):
        """Store in ``self.reaction_colormap`` a function mapping a value to a hex colour.

        Values are linearly normalised from ``[-amplitude, amplitude]`` and
        looked up in the matplotlib colormap named *cmap*.
        """
        norm = colors.Normalize(vmin=-amplitude, vmax=amplitude)
        # Resolve the colormap once.  pyplot.get_cmap is used because
        # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
        colormap = plt.get_cmap(cmap)
        self.reaction_colormap = lambda value: colors.to_hex(colormap(norm(value)))

    def show(self, method='shell'):
        """Draw the graph with matplotlib and display it.

        Edge widths are proportional to ``|adiff|``; the largest edge is
        drawn at width 5.  Node colours come from each node's ``color``
        attribute and positions from ``self.pos``.

        Args:
            method: Unused; kept for backward compatibility.
        """
        graph = self.graph
        widths = np.abs(np.array(
            [graph[u][v]['adiff'] for u, v in graph.edges], dtype=float
        ))
        peak = widths.max() if widths.size else 0.0
        # Skip normalisation when there are no edges (max of nothing) or every
        # adiff is 0 (0/0 would produce NaN widths).
        if peak > 0:
            widths = widths / peak

        node_colors = [graph.nodes[node]['color'] for node in graph.nodes]
        networkx.draw_networkx(
            graph, self.pos,
            with_labels=True,
            arrows=True,
            width=widths * 5,
            node_color=node_colors,
        )
        plt.show()

    @staticmethod
    def _arrow_path(start, end, arrow_length, arrow_angle):
        """Return ``(xs, ys)`` for a segment start -> end with a two-barb arrowhead.

        Each list holds six entries: start, end, first barb tip, end, second
        barb tip, and a ``None`` separator.  The separator stops plotly from
        joining this path to the next one in the same trace.
        """
        (x0, y0), (x1, y1) = start, end
        fi = math.atan2(y1 - y0, x1 - x0)
        xs = [
            x0, x1,
            x1 - arrow_length * math.cos(fi - arrow_angle),
            x1,
            x1 - arrow_length * math.cos(fi + arrow_angle),
            None,
        ]
        ys = [
            y0, y1,
            y1 - arrow_length * math.sin(fi - arrow_angle),
            y1,
            y1 - arrow_length * math.sin(fi + arrow_angle),
            None,
        ]
        return xs, ys

    def edges(
            self, arrow_length=0.07, arrow_angle=0.3,
            threshold=0.1
    ):
        """Build one plotly line trace per reaction, with arrowheads.

        Each trace contains every edge into and out of the reaction node.  It
        is coloured by ``self.reaction_colormap(rdiff)``.  When
        ``|rdiff| < threshold`` it starts hidden (``visible='legendonly'``).

        Args:
            arrow_length: Length of each arrowhead barb, in layout coordinates.
            arrow_angle: Angle between each barb and its edge, in radians.
            threshold: Minimum ``|rdiff|`` for a trace to be shown initially.

        Returns:
            list of ``go.Scatter`` traces, one per entry of ``self.reactions``.
        """
        G = self.graph
        edge_traces = []
        for reaction in self.reactions:
            rid = reaction.reaction_id
            # Every edge touching the reaction node carries the same attributes
            # as the node (see _write_graph).  Read rdiff from the node, not
            # from the leaked inner-loop variable: that variable is stale or
            # unbound for a reaction without edges.
            rdiff = G.nodes[rid]['rdiff']
            label = f"{rdiff * 100:.1f}%"
            edge_x, edge_y = [], []
            for u, v in (*G.out_edges(rid), *G.in_edges(rid)):
                xs, ys = self._arrow_path(self.pos[u], self.pos[v], arrow_length, arrow_angle)
                edge_x.extend(xs)
                edge_y.extend(ys)
            extra_kw = {'visible': 'legendonly'} if abs(rdiff) < threshold else {}
            edge_traces.append(go.Scatter(
                name=reaction.desc,
                x=edge_x, y=edge_y,
                text=[label] * len(edge_x),
                line=dict(width=3, color=self.reaction_colormap(rdiff)),
                hoverinfo='text',
                mode='lines',
                **extra_kw
            ))
        return edge_traces

    def nodes(
            self,
            nodetype='all', mode='markers',
            hoverinfo='text', size=10, color=(),
            marker='circle'
    ):
        """Build a single plotly scatter trace for graph nodes.

        Args:
            nodetype: Which nodes to include.
                ``'all'``: every node, with no hover text.
                ``'reactions'``: reaction nodes, coloured by ``rdiff_nonabs``
                on a RdBu scale with a colour bar; *color* is ignored.
                ``'species'``: species nodes, labelled with their names.
            mode: Passed to ``go.Scatter``.
            hoverinfo: Passed to ``go.Scatter``.
            size: Marker size (``sizemode='area'``).
            color: Marker colour(s) for ``'all'`` and ``'species'``.
            marker: Plotly marker symbol name.

        Returns:
            a ``go.Scatter`` trace named *nodetype*.

        Raises:
            ValueError: for an unknown *nodetype*.
        """
        # Validate up front so a bad value is reported even for an empty graph.
        if nodetype not in ('all', 'reactions', 'species'):
            raise ValueError("nodetype must be one of: ['all', 'reactions', 'species']")
        G = self.graph
        species = set(self.species)  # O(1) membership tests inside the loop
        node_x, node_y, node_text = [], [], []
        node_color = [] if nodetype == 'reactions' else color

        for node in G.nodes():
            text = ''
            if nodetype == 'reactions':
                if node not in self.reactionsdict:
                    continue
                text = self.reactionsdict[node].desc
                node_color.append(G.nodes[node]['rdiff_nonabs'])
            elif nodetype == 'species':
                if node not in species:
                    continue
                text = node
            x, y = self.pos[node]
            node_x.append(x)
            node_y.append(y)
            node_text.append(text)

        extrakw = {}
        # The colour scale and bar are attached only when at least one
        # reaction node was plotted.
        if nodetype == 'reactions' and node_x:
            extrakw = dict(
                colorbar=dict(
                    thickness=15,
                    # title.side replaces the ``titleside`` attribute that
                    # plotly.py 6 removed.
                    title=dict(text='Reaction relative rate', side='right'),
                    xanchor='right',
                ),
                colorscale='RdBu',
                reversescale=True,
                cmid=0.,
                cmax=0.3,
                cmin=-0.3,
                showscale=True,
            )

        return go.Scatter(
            x=node_x, y=node_y,
            mode=mode,
            hoverinfo=hoverinfo,
            text=node_text,
            marker=dict(
                symbol=marker,
                color=node_color,
                size=size,
                sizemode='area',
                sizeref=2. * size / (1. ** 2),
                line_width=2,
                **extrakw
            ),
            name=nodetype,
        )

    def plotly_traces(self):
        """Return all plotly traces: reaction edges, then reaction nodes, then species nodes."""
        edge_traces = self.edges()
        node_reactions_trace = self.nodes(nodetype='reactions', color='#ffffaa', size=10)
        node_species_trace = self.nodes(
            nodetype='species', color='#aaaaff', size=50, mode='markers+text',
            marker='square'
        )
        return [*edge_traces, node_reactions_trace, node_species_trace]

    def plotly(self):
        """Build a plotly ``Figure`` of the graph, titled with ``self.title``.

        Axes are hidden and the y axis is anchored to the x axis scale.
        """
        return go.Figure(
            data=self.plotly_traces(),
            layout=go.Layout(
                # title.font replaces ``titlefont_size``, which plotly.py 6 removed.
                title=dict(text=self.title, font=dict(size=16)),
                showlegend=True,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, scaleanchor='x'))
        )

Classes

class Snapshot (entry, layout=None)

Class to handle an ALCHEMIC reaction graph

Expand source code
class Snapshot(object):
    """Class to handle an ALCHEMIC reaction graph.

    The text *entry* is parsed as follows.  Its second line is the title,
    which must contain ``at <time> [years]``; the first token of the title
    (after stripping leading ``#``) names the main species.  Every line from
    the fourth onward is parsed by ``Reaction``.

    The reactions are turned into a ``networkx.DiGraph`` with species and
    reaction nodes (reactant -> reaction -> product).  The graph can be drawn
    with matplotlib (``show``) or plotly (``plotly``).
    """

    def __init__(self, entry: str, layout=None):
        """Parse *entry* and build the graph, node positions and colormap.

        Args:
            entry: Text block describing one snapshot (see class docstring).
            layout: Passed to ``graph_positions`` (``None``/``'shell'`` or
                ``'kamada_kawai'``).

        Raises:
            IndexError: if *entry* has fewer than two lines or the title lacks
                ``at <time> [years]``.
            ValueError: if the captured time is not a float, or *layout* is
                unknown.
        """
        lines = entry.strip().splitlines()
        self.title = lines[1]
        self.time = float(re.findall(r'at\s*(\S*)\s*\[years\]', self.title)[0])
        self.mainspec = self.title.lstrip("#").split(maxsplit=1)[0]
        self.reactions = []
        self.reactionsdict = {}
        # Lines 0 and 2 are skipped; reaction lines start at index 3.
        for line in lines[3:]:
            reaction = Reaction(line)
            self.reactions.append(reaction)
            self.reactionsdict[reaction.reaction_id] = reaction
        self._write_all_species()
        self._write_graph()
        self.graph_positions(layout=layout)
        self._reaction_colormap()

    def __repr__(self):
        return self.title

    def _write_all_species_set(self):
        """Collect the unique species of all reactions into the set ``self.species``.

        Unlike ``_write_all_species`` this does not preserve first-seen order.
        """
        # Must be a set: ``{}`` is an empty dict, whose ``update`` would try to
        # unpack each species name into a (key, value) pair and fail.
        self.species = set()
        for reaction in self.reactions:
            self.species.update(reaction.reactants, reaction.products)

    def _write_all_species(self):
        """Store the unique species of all reactions in ``self.species``.

        The result is a list in first-seen order: for each reaction, its
        reactants and then its products.
        """
        # dict.fromkeys de-duplicates while keeping insertion order.
        self.species = list(dict.fromkeys(
            species
            for reaction in self.reactions
            for species in (*reaction.reactants, *reaction.products)
        ))

    def _write_graph(self):
        """Build ``self.graph``, a DiGraph of species and reaction nodes.

        Species nodes get ``color='#aaaaff', type='species'``.  Each reaction
        becomes a node (``color='#ffff00', type='reaction'``) linked by edges
        reactant -> reaction -> product.  The reaction node and all of its
        edges carry ``adiff``, ``rdiff``, ``rdiff_nonabs`` and ``weight``.
        """
        self.graph = networkx.DiGraph()
        self.graph.add_nodes_from(
            self.species,
            color='#aaaaff',
            type='species'
        )

        for reaction in self.reactions:
            kw = {
                'adiff': reaction.adiff, 'rdiff': reaction.rdiff,
                'rdiff_nonabs': reaction.rdiff_nonabs,
                # Larger |rdiff| gives a smaller weight.  NOTE(review): an
                # rdiff of 0 yields an infinite weight.
                'weight': 1 / np.sqrt(np.abs(reaction.rdiff))
            }
            self.graph.add_node(
                reaction.reaction_id,
                color='#ffff00',
                type='reaction',
                **kw,
            )
            for reactant in reaction.reactants:
                self.graph.add_edge(reactant, reaction.reaction_id, **kw)
            for product in reaction.products:
                self.graph.add_edge(reaction.reaction_id, product, **kw)

    def graph_positions(self, layout=None, all_reactions=None, all_species=None):
        """Compute node positions, store them in ``self.pos`` and return them.

        Args:
            layout: ``None`` or ``'shell'`` gives three concentric shells:
                the main species in the centre, reactions in the middle ring
                and the other species in the outer ring.  ``'kamada_kawai'``
                uses ``networkx.kamada_kawai_layout``.
            all_reactions: Node ids for the middle shell.  Defaults to the ids
                of ``self.reactions``.
            all_species: Species for the outer shell.  Defaults to
                ``self.species``.  The main species is always left out of it.

        Returns:
            dict mapping node -> position, in networkx ``pos`` format.

        Raises:
            ValueError: for an unknown *layout*.
        """
        if layout is None or layout == 'shell':
            if all_species is None:
                all_species = self.species
            if all_reactions is None:
                all_reactions = [reaction.reaction_id for reaction in self.reactions]
            # Filter instead of list.remove(): this works for any iterable and
            # does not fail when the main species is absent from all_species.
            outer_species = [s for s in all_species if s != self.mainspec]
            pos = networkx.shell_layout(
                self.graph,
                [
                    [self.mainspec],
                    list(all_reactions),
                    outer_species,
                ]
            )
        elif layout == 'kamada_kawai':
            pos = networkx.kamada_kawai_layout(self.graph)
        else:
            raise ValueError(
                f"layout must be None, 'shell' or 'kamada_kawai', got {layout!r}"
            )
        self.pos = pos
        return pos

    def _reaction_colormap(self, amplitude=0.3, cmap='RdBu_r'):
        """Store in ``self.reaction_colormap`` a function mapping a value to a hex colour.

        Values are linearly normalised from ``[-amplitude, amplitude]`` and
        looked up in the matplotlib colormap named *cmap*.
        """
        norm = colors.Normalize(vmin=-amplitude, vmax=amplitude)
        # Resolve the colormap once.  pyplot.get_cmap is used because
        # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
        colormap = plt.get_cmap(cmap)
        self.reaction_colormap = lambda value: colors.to_hex(colormap(norm(value)))

    def show(self, method='shell'):
        """Draw the graph with matplotlib and display it.

        Edge widths are proportional to ``|adiff|``; the largest edge is
        drawn at width 5.  Node colours come from each node's ``color``
        attribute and positions from ``self.pos``.

        Args:
            method: Unused; kept for backward compatibility.
        """
        graph = self.graph
        widths = np.abs(np.array(
            [graph[u][v]['adiff'] for u, v in graph.edges], dtype=float
        ))
        peak = widths.max() if widths.size else 0.0
        # Skip normalisation when there are no edges (max of nothing) or every
        # adiff is 0 (0/0 would produce NaN widths).
        if peak > 0:
            widths = widths / peak

        node_colors = [graph.nodes[node]['color'] for node in graph.nodes]
        networkx.draw_networkx(
            graph, self.pos,
            with_labels=True,
            arrows=True,
            width=widths * 5,
            node_color=node_colors,
        )
        plt.show()

    @staticmethod
    def _arrow_path(start, end, arrow_length, arrow_angle):
        """Return ``(xs, ys)`` for a segment start -> end with a two-barb arrowhead.

        Each list holds six entries: start, end, first barb tip, end, second
        barb tip, and a ``None`` separator.  The separator stops plotly from
        joining this path to the next one in the same trace.
        """
        (x0, y0), (x1, y1) = start, end
        fi = math.atan2(y1 - y0, x1 - x0)
        xs = [
            x0, x1,
            x1 - arrow_length * math.cos(fi - arrow_angle),
            x1,
            x1 - arrow_length * math.cos(fi + arrow_angle),
            None,
        ]
        ys = [
            y0, y1,
            y1 - arrow_length * math.sin(fi - arrow_angle),
            y1,
            y1 - arrow_length * math.sin(fi + arrow_angle),
            None,
        ]
        return xs, ys

    def edges(
            self, arrow_length=0.07, arrow_angle=0.3,
            threshold=0.1
    ):
        """Build one plotly line trace per reaction, with arrowheads.

        Each trace contains every edge into and out of the reaction node.  It
        is coloured by ``self.reaction_colormap(rdiff)``.  When
        ``|rdiff| < threshold`` it starts hidden (``visible='legendonly'``).

        Args:
            arrow_length: Length of each arrowhead barb, in layout coordinates.
            arrow_angle: Angle between each barb and its edge, in radians.
            threshold: Minimum ``|rdiff|`` for a trace to be shown initially.

        Returns:
            list of ``go.Scatter`` traces, one per entry of ``self.reactions``.
        """
        G = self.graph
        edge_traces = []
        for reaction in self.reactions:
            rid = reaction.reaction_id
            # Every edge touching the reaction node carries the same attributes
            # as the node (see _write_graph).  Read rdiff from the node, not
            # from the leaked inner-loop variable: that variable is stale or
            # unbound for a reaction without edges.
            rdiff = G.nodes[rid]['rdiff']
            label = f"{rdiff * 100:.1f}%"
            edge_x, edge_y = [], []
            for u, v in (*G.out_edges(rid), *G.in_edges(rid)):
                xs, ys = self._arrow_path(self.pos[u], self.pos[v], arrow_length, arrow_angle)
                edge_x.extend(xs)
                edge_y.extend(ys)
            extra_kw = {'visible': 'legendonly'} if abs(rdiff) < threshold else {}
            edge_traces.append(go.Scatter(
                name=reaction.desc,
                x=edge_x, y=edge_y,
                text=[label] * len(edge_x),
                line=dict(width=3, color=self.reaction_colormap(rdiff)),
                hoverinfo='text',
                mode='lines',
                **extra_kw
            ))
        return edge_traces

    def nodes(
            self,
            nodetype='all', mode='markers',
            hoverinfo='text', size=10, color=(),
            marker='circle'
    ):
        """Build a single plotly scatter trace for graph nodes.

        Args:
            nodetype: Which nodes to include.
                ``'all'``: every node, with no hover text.
                ``'reactions'``: reaction nodes, coloured by ``rdiff_nonabs``
                on a RdBu scale with a colour bar; *color* is ignored.
                ``'species'``: species nodes, labelled with their names.
            mode: Passed to ``go.Scatter``.
            hoverinfo: Passed to ``go.Scatter``.
            size: Marker size (``sizemode='area'``).
            color: Marker colour(s) for ``'all'`` and ``'species'``.
            marker: Plotly marker symbol name.

        Returns:
            a ``go.Scatter`` trace named *nodetype*.

        Raises:
            ValueError: for an unknown *nodetype*.
        """
        # Validate up front so a bad value is reported even for an empty graph.
        if nodetype not in ('all', 'reactions', 'species'):
            raise ValueError("nodetype must be one of: ['all', 'reactions', 'species']")
        G = self.graph
        species = set(self.species)  # O(1) membership tests inside the loop
        node_x, node_y, node_text = [], [], []
        node_color = [] if nodetype == 'reactions' else color

        for node in G.nodes():
            text = ''
            if nodetype == 'reactions':
                if node not in self.reactionsdict:
                    continue
                text = self.reactionsdict[node].desc
                node_color.append(G.nodes[node]['rdiff_nonabs'])
            elif nodetype == 'species':
                if node not in species:
                    continue
                text = node
            x, y = self.pos[node]
            node_x.append(x)
            node_y.append(y)
            node_text.append(text)

        extrakw = {}
        # The colour scale and bar are attached only when at least one
        # reaction node was plotted.
        if nodetype == 'reactions' and node_x:
            extrakw = dict(
                colorbar=dict(
                    thickness=15,
                    # title.side replaces the ``titleside`` attribute that
                    # plotly.py 6 removed.
                    title=dict(text='Reaction relative rate', side='right'),
                    xanchor='right',
                ),
                colorscale='RdBu',
                reversescale=True,
                cmid=0.,
                cmax=0.3,
                cmin=-0.3,
                showscale=True,
            )

        return go.Scatter(
            x=node_x, y=node_y,
            mode=mode,
            hoverinfo=hoverinfo,
            text=node_text,
            marker=dict(
                symbol=marker,
                color=node_color,
                size=size,
                sizemode='area',
                sizeref=2. * size / (1. ** 2),
                line_width=2,
                **extrakw
            ),
            name=nodetype,
        )

    def plotly_traces(self):
        """Return all plotly traces: reaction edges, then reaction nodes, then species nodes."""
        edge_traces = self.edges()
        node_reactions_trace = self.nodes(nodetype='reactions', color='#ffffaa', size=10)
        node_species_trace = self.nodes(
            nodetype='species', color='#aaaaff', size=50, mode='markers+text',
            marker='square'
        )
        return [*edge_traces, node_reactions_trace, node_species_trace]

    def plotly(self):
        """Build a plotly ``Figure`` of the graph, titled with ``self.title``.

        Axes are hidden and the y axis is anchored to the x axis scale.
        """
        return go.Figure(
            data=self.plotly_traces(),
            layout=go.Layout(
                # title.font replaces ``titlefont_size``, which plotly.py 6 removed.
                title=dict(text=self.title, font=dict(size=16)),
                showlegend=True,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, scaleanchor='x'))
        )

Methods

def edges(self, arrow_length=0.07, arrow_angle=0.3, threshold=0.1)
Expand source code
def edges(
        self, arrow_length=0.07, arrow_angle=0.3,
        threshold=0.1
):
    """Build one plotly line trace per reaction, with arrowheads.

    Each trace contains every edge into and out of the reaction node.  It is
    coloured by ``self.reaction_colormap(rdiff)``.  When
    ``|rdiff| < threshold`` it starts hidden (``visible='legendonly'``).

    Args:
        arrow_length: Length of each arrowhead barb, in layout coordinates.
        arrow_angle: Angle between each barb and its edge, in radians.
        threshold: Minimum ``|rdiff|`` for a trace to be shown initially.

    Returns:
        list of ``go.Scatter`` traces, one per entry of ``self.reactions``.
    """
    G = self.graph
    edge_traces = []
    for reaction in self.reactions:
        rid = reaction.reaction_id
        # The reaction node carries the same rdiff as all of its edges.  Read
        # it from the node, not from the leaked inner-loop variable: that
        # variable is stale or unbound for a reaction without edges.
        rdiff = G.nodes[rid]['rdiff']
        label = f"{rdiff * 100:.1f}%"
        edge_x, edge_y = [], []
        for u, v in (*G.out_edges(rid), *G.in_edges(rid)):
            x0, y0 = self.pos[u]
            x1, y1 = self.pos[v]
            fi = math.atan2(y1 - y0, x1 - x0)
            # Segment, then the two arrowhead barbs.  The trailing None breaks
            # the line so plotly does not join this edge to the next one.
            edge_x += [
                x0, x1,
                x1 - arrow_length * math.cos(fi - arrow_angle),
                x1,
                x1 - arrow_length * math.cos(fi + arrow_angle),
                None,
            ]
            edge_y += [
                y0, y1,
                y1 - arrow_length * math.sin(fi - arrow_angle),
                y1,
                y1 - arrow_length * math.sin(fi + arrow_angle),
                None,
            ]
        extra_kw = {'visible': 'legendonly'} if abs(rdiff) < threshold else {}
        edge_traces.append(go.Scatter(
            name=reaction.desc,
            x=edge_x, y=edge_y,
            text=[label] * len(edge_x),
            line=dict(width=3, color=self.reaction_colormap(rdiff)),
            hoverinfo='text',
            mode='lines',
            **extra_kw
        ))
    return edge_traces
def graph_positions(self, layout=None, all_reactions=None, all_species=None)

Return positions of graph nodes in networkx pos format

Expand source code
def graph_positions(self, layout=None, all_reactions=None, all_species=None):
    """Compute node positions, store them in ``self.pos`` and return them.

    Args:
        layout: ``None`` or ``'shell'`` gives three concentric shells: the
            main species in the centre, reactions in the middle ring and the
            other species in the outer ring.  ``'kamada_kawai'`` uses
            ``networkx.kamada_kawai_layout``.
        all_reactions: Node ids for the middle shell.  Defaults to the ids of
            ``self.reactions``.
        all_species: Species for the outer shell.  Defaults to
            ``self.species``.  The main species is always left out of it.

    Returns:
        dict mapping node -> position, in networkx ``pos`` format.

    Raises:
        ValueError: for an unknown *layout*.
    """
    if layout is None or layout == 'shell':
        if all_species is None:
            all_species = self.species
        if all_reactions is None:
            all_reactions = [reaction.reaction_id for reaction in self.reactions]
        # Filter instead of list.remove(): this works for any iterable and does
        # not fail when the main species is absent from all_species.
        outer_species = [s for s in all_species if s != self.mainspec]
        pos = networkx.shell_layout(
            self.graph,
            [
                [self.mainspec],
                list(all_reactions),
                outer_species,
            ]
        )
    elif layout == 'kamada_kawai':
        pos = networkx.kamada_kawai_layout(self.graph)
    else:
        raise ValueError(
            f"layout must be None, 'shell' or 'kamada_kawai', got {layout!r}"
        )
    self.pos = pos
    return pos
def nodes(self, nodetype='all', mode='markers', hoverinfo='text', size=10, color=(), marker='circle')
Expand source code
def nodes(
        self,
        nodetype='all', mode='markers',
        hoverinfo='text', size=10, color=(),
        marker='circle'
):
    """Build a single plotly scatter trace for graph nodes.

    Args:
        nodetype: Which nodes to include.
            ``'all'``: every node, with no hover text.
            ``'reactions'``: reaction nodes, coloured by ``rdiff_nonabs`` on
            a RdBu scale with a colour bar; *color* is ignored.
            ``'species'``: species nodes, labelled with their names.
        mode: Passed to ``go.Scatter``.
        hoverinfo: Passed to ``go.Scatter``.
        size: Marker size (``sizemode='area'``).
        color: Marker colour(s) for ``'all'`` and ``'species'``.
        marker: Plotly marker symbol name.

    Returns:
        a ``go.Scatter`` trace named *nodetype*.

    Raises:
        ValueError: for an unknown *nodetype*.
    """
    # Validate up front so a bad value is reported even for an empty graph.
    if nodetype not in ('all', 'reactions', 'species'):
        raise ValueError("nodetype must be one of: ['all', 'reactions', 'species']")
    G = self.graph
    species = set(self.species)  # O(1) membership tests inside the loop
    node_x, node_y, node_text = [], [], []
    node_color = [] if nodetype == 'reactions' else color

    for node in G.nodes():
        text = ''
        if nodetype == 'reactions':
            if node not in self.reactionsdict:
                continue
            text = self.reactionsdict[node].desc
            node_color.append(G.nodes[node]['rdiff_nonabs'])
        elif nodetype == 'species':
            if node not in species:
                continue
            text = node
        x, y = self.pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(text)

    extrakw = {}
    # The colour scale and bar are attached only when at least one reaction
    # node was plotted.
    if nodetype == 'reactions' and node_x:
        extrakw = dict(
            colorbar=dict(
                thickness=15,
                # title.side replaces the ``titleside`` attribute that
                # plotly.py 6 removed.
                title=dict(text='Reaction relative rate', side='right'),
                xanchor='right',
            ),
            colorscale='RdBu',
            reversescale=True,
            cmid=0.,
            cmax=0.3,
            cmin=-0.3,
            showscale=True,
        )

    return go.Scatter(
        x=node_x, y=node_y,
        mode=mode,
        hoverinfo=hoverinfo,
        text=node_text,
        marker=dict(
            symbol=marker,
            color=node_color,
            size=size,
            sizemode='area',
            sizeref=2. * size / (1. ** 2),
            line_width=2,
            **extrakw
        ),
        name=nodetype,
    )
def plotly(self)

Generate plotly figure

Expand source code
def plotly(self):
    """Build a plotly ``Figure`` of the graph, titled with ``self.title``.

    Axes are hidden and the y axis is anchored to the x axis scale.
    """
    return go.Figure(
        data=self.plotly_traces(),
        layout=go.Layout(
            # title.font replaces ``titlefont_size``, which plotly.py 6 removed.
            title=dict(text=self.title, font=dict(size=16)),
            showlegend=True,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, scaleanchor='x'))
    )
def plotly_traces(self)

Generate plotly traces

Expand source code
def plotly_traces(self):
    """Collect every plotly trace for this snapshot.

    Order: the per-reaction edge traces from ``edges()``, then the
    reaction-node trace, then the species-node trace last.
    """
    reaction_kw = dict(nodetype='reactions', color='#ffffaa', size=10)
    species_kw = dict(
        nodetype='species', color='#aaaaff', size=50,
        mode='markers+text', marker='square',
    )
    traces = list(self.edges())
    traces.append(self.nodes(**reaction_kw))
    traces.append(self.nodes(**species_kw))
    return traces
def show(self, method='shell')

Plot graph

Expand source code
def show(self, method='shell'):
    """Draw the graph with matplotlib and display it.

    Edge widths are proportional to ``|adiff|``; the largest edge is drawn at
    width 5.  Node colours come from each node's ``color`` attribute and
    positions from ``self.pos``.

    Args:
        method: Unused; kept for backward compatibility.
    """
    graph = self.graph
    widths = np.abs(np.array(
        [graph[u][v]['adiff'] for u, v in graph.edges], dtype=float
    ))
    peak = widths.max() if widths.size else 0.0
    # Skip normalisation when there are no edges (max of nothing) or every
    # adiff is 0 (0/0 would produce NaN widths).
    if peak > 0:
        widths = widths / peak

    # Named node_colors so the local does not shadow the matplotlib ``colors``
    # module imported at file level.
    node_colors = [graph.nodes[node]['color'] for node in graph.nodes]
    networkx.draw_networkx(
        graph, self.pos,
        with_labels=True,
        arrows=True,
        width=widths * 5,
        node_color=node_colors,
    )
    plt.show()