Source code for skypy.pipeline._pipeline

"""Pipeline implementation.

This module provides methods to run pipelines of functions with dependencies
and handle their results.
"""

from astropy.table import Table, Column
from copy import copy, deepcopy
from collections.abc import Sequence, Mapping
from ._config import load_skypy_yaml
from ._items import Item, Call, Ref
from . import log
import networkx
import pathlib


__all__ = [
    'Pipeline',
]


class Pipeline:
    r'''Class for running pipelines.

    This is the main class for running pipelines of functions with
    dependencies and using their results to generate variables and tables.
    '''
    @classmethod
    def read(cls, filename):
        '''Read a pipeline from a configuration file.

        Parameters
        ----------
        filename : str
            The name of the configuration file.

        '''
        config = load_skypy_yaml(filename)
        return cls(config)
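    # Usage sketch (not part of the module; the filename is hypothetical):
    # the typical entry point is to read a YAML configuration and run it.
    #
    #     from skypy.pipeline import Pipeline
    #     pipeline = Pipeline.read('my_config.yml')
    #     pipeline.execute()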
    def __init__(self, configuration):
        '''Construct the pipeline.

        Parameters
        ----------
        configuration : dict-like
            Configuration for the pipeline.

        Notes
        -----
        Each step in the pipeline is configured by a dictionary specifying
        a variable name and the associated value.

        If a function argument is a string `$variable_name`, it refers to the
        value of a previous step in the pipeline.

        'configuration' should contain the name and configuration of each
        variable and/or an entry named 'tables'. 'tables' should contain a
        set of nested dictionaries, first containing the name of each table,
        then the name and configuration of each column and optionally an
        entry named '.init' with a configuration that initialises the table.
        If '.init' is not specified, the table will be initialised as an
        empty astropy Table by default.

        See [1]_ for examples of pipeline configurations in YAML format.

        References
        ----------
        .. [1] https://github.com/skypyproject/skypy/tree/main/examples

        '''
        # config contains settings for all variables and table initialisation
        # table_config contains settings for all table columns
        self.config = deepcopy(configuration)
        self.cosmology = self.config.pop('cosmology', None)
        self.parameters = self.config.pop('parameters', {})
        self.table_config = self.config.pop('tables', {})
        default_table = Call(Table)
        self.config.update({k: v.pop('.init', default_table)
                            for k, v in self.table_config.items()})

        # Initialise state with parameters
        self.state = copy(self.parameters)

        # Create a Directed Acyclic Graph of all jobs and dependencies
        self.dag = networkx.DiGraph()

        # context for function calls
        context = {}

        # use cosmology in global context if given
        if self.cosmology is not None:
            context['cosmology'] = Ref('cosmology')

        # - add nodes for each variable, table and column
        # - add edges for the table dependencies
        # - keep track of where items need to be called
        items = {}
        for job, settings in self.config.items():
            self.dag.add_node(job, skip=False)
            if isinstance(settings, Item):
                items[job] = settings
        for table, columns in self.table_config.items():
            table_complete = '.'.join((table, 'complete'))
            self.dag.add_node(table_complete)
            self.dag.add_edge(table, table_complete)
            for column, settings in columns.items():
                job = '.'.join((table, column))
                self.dag.add_node(job, skip=False)
                self.dag.add_edge(table, job)
                self.dag.add_edge(job, table_complete)
                if isinstance(settings, Item):
                    items[job] = settings
                # DAG nodes for individual columns in multi-column assignment
                names = [n.strip() for n in column.split(',')]
                if len(names) > 1:
                    for name in names:
                        subjob = '.'.join((table, name))
                        self.dag.add_node(subjob)
                        self.dag.add_edge(job, subjob)

        # go through items and add edges for all dependencies
        for job, settings in items.items():
            # get dependencies from item
            deps = settings.depend(self)
            # add edges for dependencies
            for d in deps:
                # job depends on d
                self.dag.add_edge(d, job)
                # recurse dependencies such that d = 'a.b.c' -> 'a.b' -> 'a'
                c = d.rpartition('.')[0]
                while c:
                    self.dag.add_edge(c, d)
                    c, d = c.rpartition('.')[0], c

            # infer additional item properties from context
            settings.infer(context)
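    # Configuration sketch (assumed, based on the docstring and the examples
    # linked above; names and the `!function` tag syntax are illustrative):
    # a YAML file loaded via `Pipeline.read` might define one parameter and
    # one table with a single column, where `$size` refers back to the
    # parameter:
    #
    #     parameters:
    #       size: 100
    #     tables:
    #       galaxies:
    #         redshift: !numpy.random.uniform
    #           low: 0.0
    #           high: 2.0
    #           size: $size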
    def execute(self, parameters={}):
        r'''Run a pipeline.

        This function runs a pipeline of functions to generate variables and
        the columns of a set of tables. It uses a Directed Acyclic Graph to
        determine a non-blocking order of execution that resolves any
        dependencies, see [1]_.

        Parameters
        ----------
        parameters : dict
            Updated parameter values for this execution.

        References
        ----------
        .. [1] https://networkx.github.io/documentation/stable/

        '''
        # update parameter state
        self.parameters.update(parameters)

        # initialise state object
        self.state = copy(self.parameters)

        # Initialise cosmology from config parameters
        if self.cosmology is not None:
            log.info("Setting cosmology")
            self.state['cosmology'] = self.evaluate(self.cosmology)

        # go through the jobs in dependency order
        for job in networkx.topological_sort(self.dag):
            node = self.dag.nodes[job]
            skip = node.get('skip', True)
            if skip:
                continue
            log.info(f"Generating {job}")
            if job in self.config:
                settings = self.config.get(job)
                self.state[job] = self.evaluate(settings)
            else:
                table, column = job.split('.')
                settings = self.table_config[table][column]
                names = [n.strip() for n in column.split(',')]
                if len(names) > 1:
                    # Multi-column assignment
                    t = Table(self.evaluate(settings), names=names)
                    self.state[table].add_columns(t.columns)
                else:
                    # Single column assignment
                    self.state[table][column] = self.evaluate(settings)
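    # Usage sketch (hypothetical parameter name): the same pipeline can be
    # re-run with updated parameter values, e.g. to scan over a grid.
    #
    #     pipeline = Pipeline.read('my_config.yml')
    #     for size in (10, 100, 1000):
    #         pipeline.execute(parameters={'size': size})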
    def write(self, filename, overwrite=False):
        r'''Write pipeline results to disk.

        Parameters
        ----------
        filename : str
            Name of output file to be written. It must have one of the
            supported file extensions for FITS (.fit .fits .fts) or HDF5
            (.hdf5 .hd5 .he5 .h5).
        overwrite : bool
            If filename already exists, this flag indicates whether or not
            to overwrite it (without warning).

        '''
        suffix = pathlib.Path(filename).suffix.lower()
        _fits_suffixes = ('.fit', '.fits', '.fts')
        _hdf5_suffixes = ('.hdf5', '.hd5', '.he5', '.h5')

        if suffix in _fits_suffixes:
            self.write_fits(filename, overwrite)
        elif suffix in _hdf5_suffixes:
            self.write_hdf5(filename, overwrite)
        else:
            raise ValueError(f'{suffix} is an unsupported file format. SkyPy supports '
                             'FITS (' + ' '.join(_fits_suffixes) + ') and '
                             'HDF5 (' + ' '.join(_hdf5_suffixes) + ').')
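    # Usage sketch (hypothetical filenames): the output format is chosen
    # from the file extension, so each call dispatches to the matching
    # writer.
    #
    #     pipeline.write('results.fits', overwrite=True)   # -> write_fits
    #     pipeline.write('results.hdf5', overwrite=True)   # -> write_hdf5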
    def write_fits(self, filename, overwrite=False):
        r'''Write pipeline results to a FITS file.

        Parameters
        ----------
        filename : str
            Name of output file to be written.
        overwrite : bool
            If filename already exists, this flag indicates whether or not
            to overwrite it (without warning).

        '''
        from astropy.io.fits import HDUList, PrimaryHDU, table_to_hdu
        hdul = [PrimaryHDU()]
        for t in self.table_config:
            hdu = table_to_hdu(self[t])
            hdu.header['EXTNAME'] = t
            hdul.append(hdu)
        HDUList(hdul).writeto(filename, overwrite=overwrite)
    def write_hdf5(self, filename, overwrite=False):
        r'''Write pipeline results to an HDF5 file.

        Parameters
        ----------
        filename : str
            Name of output file to be written.
        overwrite : bool
            If filename already exists, this flag indicates whether or not
            to overwrite it (without warning).

        '''
        for t in self.table_config:
            self[t].write(filename, path=f'tables/{t}', append=True,
                          overwrite=overwrite)
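    # Round-trip sketch (hypothetical filename): each table is written under
    # the 'tables/' path, so it can be read back with astropy:
    #
    #     from astropy.table import Table
    #     galaxies = Table.read('results.hdf5', path='tables/galaxies')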
    def evaluate(self, value):
        '''Evaluate an item in the pipeline.'''
        if isinstance(value, Sequence) and not isinstance(value, str):
            # recurse lists
            return [self.evaluate(v) for v in value]
        elif isinstance(value, Mapping):
            # recurse dicts
            return {k: self.evaluate(v) for k, v in value.items()}
        elif isinstance(value, Item):
            # evaluate item
            return value.evaluate(self)
        else:
            # everything else is returned unchanged
            return value
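    # Behaviour sketch: `evaluate` walks nested containers and only calls
    # `Item.evaluate` on pipeline items, so plain values pass through
    # untouched. For example, with `ref` a hypothetical `Ref('size')`:
    #
    #     pipeline.evaluate([1, {'n': ref}, 'text'])
    #     # -> [1, {'n': <value of the 'size' step>}, 'text']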
    def depend(self, args):
        '''Get dependencies from function args.

        Returns a list of all references found.
        '''
        if isinstance(args, Sequence) and not isinstance(args, str):
            # recurse list
            return sum([self.depend(a) for a in args], [])
        elif isinstance(args, Mapping):
            # get explicit dependencies
            deps = args.pop('.depends', [])
            # turn a single value into a list
            if isinstance(deps, str) or not isinstance(deps, Sequence):
                deps = [deps]
            # recurse remaining kwargs
            return deps + sum([self.depend(a) for a in args.values()], [])
        elif isinstance(args, Item):
            # check pipeline item
            return args.depend(self)
        else:
            # no reference
            return []
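    # Configuration sketch (assumed syntax; the function name is
    # hypothetical): an explicit '.depends' entry among a call's keyword
    # arguments adds an ordering edge in the DAG even when no '$' reference
    # appears in the arguments themselves.
    #
    #     tables:
    #       galaxies:
    #         magnitude: !mymodule.my_function
    #           .depends: [galaxies.redshift]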
    def __getitem__(self, label):
        item = self.state
        name = None
        while label:
            key, _, label = label.partition('.')
            name = f'{name}.{key}' if name else key
            try:
                item = item[key]
            except KeyError as e:
                raise KeyError('unknown label: ' + name) from e
        if isinstance(item, Column):
            return item.data if item.unit is None else item.quantity
        else:
            return item
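# Access sketch (hypothetical names): results are looked up by dotted label,
# so after execution a whole table or a single column can be retrieved.
#
#     galaxies = pipeline['galaxies']            # astropy Table
#     redshift = pipeline['galaxies.redshift']   # bare data, or a Quantity
#                                                # if the column has a unit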