Skip to content
Snippets Groups Projects
tsplot 15.36 KiB
#!/usr/bin/env python3
#coding: utf-8

import argparse
import json
import numpy as np
import sys
import re
import logging
import matplotlib as M
import matplotlib.pyplot as P
import numbers
from functools import reduce
from distutils.version import LooseVersion
from itertools import chain, islice, cycle

# Refuse to run on a matplotlib older than 1.5.0: the line style handling used
# when plotting relies on it (presumably the (offset, (on, off)) dash tuples).
_mpl_version = LooseVersion(M.__version__)
if _mpl_version < LooseVersion("1.5.0"):
    logging.critical("require matplotlib version ≥ 1.5.0")
    sys.exit(1)

# Read timeseries data from multiple files, plot each in one panel, with common
# time axis, and optionally sharing a vertical axis as governed by the configuration.

def parse_clargs(argv=None):
    """Parse command line arguments and return the argparse namespace.

    argv: list of argument strings to parse; when None (the default) the
          arguments are taken from sys.argv[1:].

    The returned namespace has the attributes inputs, axis, trange, groupby,
    select, colours, outfile, list, dpi, scale and exclude. Options that were
    not supplied are None (or False for the 'list' flag).
    """
    def float_or_none(s):
        # An empty or non-numeric field stands for 'no bound on this side'.
        try:
            return float(s)
        except ValueError:
            return None

    def parse_range_spec(s):
        # 'min,max' -> (min, max). Anything other than exactly two fields
        # raises ValueError, which argparse reports as an invalid value.
        lower, upper = (float_or_none(x) for x in s.split(','))
        return (lower, upper)

    def parse_colour_spec(s):
        # 'COLOUR:EXPR,EXPR,...' -> (COLOUR, [EXPR, ...])
        colour, tests = s.split(':', 1)
        return colour, tests.split(',')

    parser = argparse.ArgumentParser(description='Plot time series data on one or more graphs.')
    parser.add_argument('inputs', metavar='FILE', nargs='+',
                        help='time series data in JSON format')
    parser.add_argument('-A', '--abscissa', metavar='AXIS', dest='axis',
                        help='use values from AXIS instead of \'time\' as abscissa')
    parser.add_argument('-t', '--trange', metavar='RANGE', dest='trange',
                        type=parse_range_spec,
                        help='restrict time axis to RANGE (see below)')
    parser.add_argument('-g', '--group', metavar='KEY,...', dest='groupby',
                        type=lambda s: s.split(','),
                        help='plot series with same KEYs on the same axes')
    parser.add_argument('-s', '--select', metavar='EXPR,...', dest='select',
                        type=lambda s: s.split(','),
                        action='append',
                        help='select only series matching EXPR')
    parser.add_argument('-c', '--colour', metavar='COLOUR:EXPR,...', dest='colours',
                        type=parse_colour_spec,
                        action='append',
                        help='use colour COLOUR as base for series matching EXPR')
    parser.add_argument('-o', '--output', metavar='FILE', dest='outfile',
                        help='save plot to file FILE')
    parser.add_argument('-l', '--list', action='store_true',
                        help='list selected time-series')
    parser.add_argument('--dpi', metavar='NUM', dest='dpi',
                        type=int,
                        help='set dpi for output image')
    parser.add_argument('--scale', metavar='NUM', dest='scale',
                        type=float,
                        help='scale size of output image by NUM')
    parser.add_argument('-x', '--exclude', metavar='NUM', dest='exclude',
                        type=float,
                        help='remove extreme points outside NUM times the 0.9-interquantile range of the median')

    parser.epilog = (
        'A range is specified by a pair of floating point numbers min,max where '
        'either may be omitted to indicate the minimum or maximum of the corresponding '
        'values in the data.'
        '\n'
        'Filter expressions are of the form KEY=VALUE. (Might add other ops later.)'
    )

    if argv is None:
        argv = sys.argv[1:]

    # Modify args to avoid argparse having a fit when it encounters an option
    # argument of the form '<negative number>,...': a leading space stops it
    # looking like an option, and float() ignores the extra whitespace.
    argsbis = [' '+a if re.match(r'-[\d.]', a) else a for a in argv]
    return parser.parse_args(argsbis)

def isstring(s):
    """Tell whether s is a text string, i.e. an instance of str."""
    is_text = isinstance(s, str)
    return is_text

def take(n, s):
    """Return a lazy iterator over at most the first n items of iterable s."""
    return islice(iter(s), n)

class TimeSeries:
    """A sampled series: abscissa values t, sample values y and metadata.

    Attributes:
        t      array of abscissa ('time') values.
        y      float array of the same shape as t holding the samples; entries
               without a sample, and excluded outliers, are NaN.
        meta   dict of metadata supplied as keyword arguments.
        ex_ts  TimeSeries of the points removed by exclude_outliers(), or None
               if exclude_outliers() has not been called.
    """

    def __init__(self, ts, ys, **kwargs):
        self.t = np.array(ts)

        # Samples are always stored as floats, independent of the dtype of
        # the time values: an integer-valued time axis must neither truncate
        # the samples nor make the NaN padding unrepresentable.
        self.y = np.full(self.t.shape, np.nan, dtype=float)

        # If there are fewer samples than times the tail stays NaN; surplus
        # samples are dropped.
        ny = min(len(ys), len(self.y))
        self.y[:ny] = ys[:ny]

        self.meta = dict(kwargs)
        self.ex_ts = None

    def trestrict(self, bounds):
        """Mask all points whose time lies outside bounds.

        bounds is a (min, max) pair; it is intersected with the time range of
        the series itself before masking. Afterwards t and y are masked arrays
        sharing the same mask.
        """
        clip = range_meet(self.trange(), bounds)
        self.t = np.ma.masked_outside(self.t, v1=clip[0], v2=clip[1])
        self.y = np.ma.masked_array(self.y, mask=self.t.mask)

    def exclude_outliers(self, iqr_factor):
        """Move extreme samples out of y into a separate series.

        A sample is kept if it is finite and lies within iqr_factor times the
        spread between the 5th and 95th percentile, measured from the median,
        of the finite samples. All other samples are replaced by NaN in y and
        collected, with their times, in the series returned by excluded().
        """
        yfinite = np.ma.masked_invalid(self.y).compressed()
        lq, median, uq = np.percentile(yfinite, [5.0, 50.0, 95.0])
        lb = median - iqr_factor*(uq-lq)
        ub = median + iqr_factor*(uq-lq)

        # Comparisons against NaN/inf may emit floating point warnings;
        # silence them just for this expression.
        with np.errstate(all='ignore'):
            # yex has the *retained* points masked, leaving the outliers.
            yex = np.ma.masked_where(np.isfinite(self.y)&(self.y<=ub)&(self.y>=lb), self.y)

        tex = np.ma.masked_array(self.t, mask=yex.mask)
        self.ex_ts = TimeSeries(tex.compressed(), yex.compressed())
        self.ex_ts.meta = dict(self.meta)

        # Inverting the mask hides the outliers instead; fill them with NaN.
        self.y = np.ma.filled(np.ma.masked_array(self.y, mask=~yex.mask), np.nan)

    def excluded(self):
        """Return the series of excluded outliers, or None."""
        return self.ex_ts

    def name(self):
        return self.meta.get('name',"")   # value of 'name' attribute in source

    def label(self):
        return self.meta.get('label',"")  # name of column in source

    def units(self):
        return self.meta.get('units',"")

    def trange(self):
        """Return (min, max) of the time values."""
        return self.t.min(), self.t.max()

    def yrange(self):
        """Return (min, max) of the sample values, ignoring NaN and masked entries.

        Returns (nan, nan) if there is no valid sample at all.
        """
        # y routinely contains NaN (padding, excluded outliers); a plain
        # min()/max() would then return NaN for the whole series.
        valid = np.ma.masked_where(np.isnan(np.ma.getdata(self.y)), self.y).compressed()
        if valid.size == 0:
            return np.nan, np.nan
        return valid.min(), valid.max()

def run_select(expr, v):
    """Evaluate the filter expression expr against the metadata mapping v.

    expr has the form KEY<op>TEST where <op> is one of
    =, !=, <, >, <=, >= (comparison), ~ (substring) or !~ (not substring).

    Returns True if the expression holds for v[KEY]. An expression that
    cannot be parsed selects everything (True); a KEY that is absent from v,
    or an ordering comparison between values that cannot be ordered, selects
    nothing (False).
    """
    m = re.match(r'([^=>!<~]+)(>=|<=|>|<|!=|=|!~|~)(.*)', expr)
    if not m:
        return True

    key, op, test = m.groups()
    if key not in v:
        return False

    val = v[key]
    if op == '~':
        return test in str(val)
    if op == '!~':
        return test not in str(val)

    if isinstance(val, numbers.Number):
        # Compare numbers with numbers: interpret the test text as a boolean,
        # integer or float where possible.
        if re.match(r'true$', test, re.I):
            test = True
        elif re.match(r'false$', test, re.I):
            test = False
        else:
            try:
                test = int(test)
            except ValueError:
                try:
                    test = float(test)
                except ValueError:
                    # Not a number: leave it as text, so that it simply
                    # compares unequal to val below.
                    pass

    if op == '=':
        return val == test
    if op == '!=':
        return val != test

    try:
        if op == '<':
            return val < test
        if op == '>':
            return val > test
        if op == '<=':
            return val <= test
        if op == '>=':
            return val >= test
    except TypeError:
        # val and test have no ordering (e.g. number vs. text, or None).
        return False

    return False


def read_json_timeseries(j, axis='time', select=None):
    """Collect TimeSeries objects from decoded JSON data.

    Convention:

    Time series data is represented by an object with a subobject 'data' and
    optionally other key/value entries comprising metadata for the time series.

    The 'data' object holds one array of numbers 'time' and zero or more other
    numeric arrays of sample values corresponding to the values in 'time'. The
    names of these other arrays are taken to be the labels for the plots.

    Units can be specified by a top level entry 'units' which is either a
    string (units are common for all data series in the object) or by a map
    that takes a label to a unit string.

    j:      decoded JSON; an object as described above, or a list whose
            entries are processed recursively.
    axis:   name of the array in 'data' to use as abscissa.
    select: optional list of terms, each a list of filter expressions for
            run_select(). A series is kept if all expressions of at least one
            term match its metadata; an empty or missing select keeps all.

    Returns the list of selected TimeSeries. Anything that does not follow
    the convention contributes no series.
    """
    # If given a list instead of a hash, collect time series from each entry.
    if isinstance(j, list):
        collected = []
        for entry in j:
            collected.extend(read_json_timeseries(entry, axis, select))
        return collected

    # Not a time series after all?
    if not isinstance(j, dict):
        return []
    jdata = j.get('data')
    if not isinstance(jdata, dict) or axis not in jdata:
        return []

    times = jdata[axis]

    def units(label):
        unitmap = j.get('units')
        if isstring(unitmap):
            return unitmap
        try:
            return unitmap[label]
        except (KeyError, TypeError, IndexError):
            # No 'units' entry, or none for this label.
            return ""

    ts_list = []
    for key, samples in jdata.items():
        if key == axis:
            continue

        # Metadata is every top-level entry except the data itself, plus the
        # column label and the units resolved for this column.
        meta = {k: val for k, val in j.items() if k != 'data'}
        meta['label'] = key
        meta['units'] = units(key)

        if not select or any(all(run_select(s, meta) for s in term) for term in select):
            ts_list.append(TimeSeries(times, samples, **meta))

    return ts_list

def min_(a, b):
    """Return the smaller of a and b; a None argument is ignored."""
    if a is None:
        return b
    if b is None:
        return a
    return min(a, b)

def max_(a, b):
    """Return the larger of a and b; a None argument is ignored."""
    if a is None:
        return b
    if b is None:
        return a
    return max(a, b)

def range_join(r, s):
    """Return the smallest (lower, upper) range covering both r and s.

    A bound that is None in one range is taken from the other.
    """
    return (min_(r[0], s[0]), max_(r[1], s[1]))

def range_meet(r, s):
    """Return the intersection of the (lower, upper) ranges r and s.

    A bound given as None places no restriction on that side, so the
    corresponding bound of the other range is used.
    """
    return (max_(r[0], s[0]), min_(r[1], s[1]))


class PlotData:
    """The series drawn together in one plot, with the group's display label.

    Attributes:
        series           list of series objects shown in this plot.
        group_key_label  text describing the grouping key (may be empty).
    """

    def __init__(self, key_label=""):
        self.series = []
        self.group_key_label = key_label

    def trange(self):
        """Return the time range covering all series in the group."""
        return reduce(range_join, [s.trange() for s in self.series])

    def yrange(self):
        """Return the value range covering all series in the group."""
        return reduce(range_join, [s.yrange() for s in self.series])

    def name(self):
        """Return the first non-empty series name in the group."""
        return reduce(lambda n, s: n or s.name(), self.series, "")

    def group_label(self):
        return self.group_key_label

    def unique_labels(self, formatter=lambda x: x):
        """Return one label per series, made distinct where metadata allows.

        Each label starts as the series' own label. While labels collide,
        metadata keys on which the series differ are appended one at a time,
        in sorted key order so the result is reproducible, as ', KEY=VALUE'
        (or just ', VALUE' for the 'name' key). Values pass through formatter
        before being converted to text.
        """
        labels = [s.label() for s in self.series]
        n = len(labels)
        if n < 2:
            return labels

        keyset = set()
        for s in self.series:
            keyset.update(s.meta.keys())

        for k in sorted(keyset):
            if len(set(labels)) == n:
                break
            if k == 'label':
                continue

            vs = [s.meta.get(k, None) for s in self.series]
            try:
                if len(set(vs)) == 1:
                    # Same value everywhere: does not help to tell them apart.
                    continue
            except TypeError:
                # Unhashable values (lists, dicts) cannot be compared this
                # way; skip the key rather than fail.
                continue

            prefix = '' if k == 'name' else k+'='
            for i, value in enumerate(vs):
                if value is not None:
                    labels[i] += ', '+prefix+str(formatter(value))

        return labels

def gather_ts_plots(tss, groupby):
    """Distribute time series over plots according to metadata keys.

    tss:     list of TimeSeries objects.
    groupby: collection of metadata keys to group on.

    Series whose metadata agree on every key in groupby share a plot. A
    series that lacks one of the keys, and every series when groupby is
    empty, gets a plot of its own.

    Returns a list of PlotData objects, one per plot, in order of first
    appearance.
    """
    group_lookup = {}
    plot_groups = []

    for ts in tss:
        key = tuple(ts.meta.get(g) for g in groupby)
        if not key or None in key or key not in group_lookup:
            pretty_key = ', '.join(str(k)+'='+str(v) for k, v in zip(groupby, key) if v is not None)
            group = PlotData(pretty_key)
            group.series = [ts]
            plot_groups.append(group)
            group_lookup[key] = len(plot_groups)-1
        else:
            plot_groups[group_lookup[key]].series.append(ts)

    return plot_groups


def make_palette(cm_name, n, cmin=0, cmax=1):
    """Return n RGBA colours sampled evenly from part of a named colour map.

    cm_name:    name of a matplotlib colour map.
    n:          number of colours to produce.
    cmin, cmax: fraction of the colour map to sample from; the n sample
                points are the midpoints of n equal slices of [cmin, cmax].
                cmin and cmax must differ.
    """
    # Choose the normalisation so that 0 maps to cmin and 1 maps to cmax.
    span = float(cmin-cmax)
    norm = M.colors.Normalize(cmin/span, (cmin-1)/span)

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap is no longer
    # available in recent matplotlib releases.
    smap = M.cm.ScalarMappable(norm, P.get_cmap(cm_name))
    return [smap.to_rgba((2*i+1)/float(2*n)) for i in range(n)]

def round_numeric_(x):
    """Helper to round numbers in labels.

    Floats are rendered as text with at most six significant digits and no
    padding; any other value is returned unchanged.
    """
    if not isinstance(x, float):
        return x
    # '.6g' sets the precision; a bare '6g' would set the field width and pad
    # short numbers with leading spaces.
    return "{:.6g}".format(x)

def plot_plots(plot_groups, axis='time', colour_overrides=None, save=None, dpi=None, scale=None):
    """Draw one panel per plot group, stacked vertically on a common time range.

    plot_groups:      list of PlotData objects, one per panel; panels are
                      ordered by group label.
    axis:             text drawn beneath the panels as the abscissa label.
    colour_overrides: list of (colour, tests) pairs. A series for which every
                      expression in tests passes run_select() is drawn in
                      that colour; if several pairs match, the last one wins.
    save:             name of the file to save the figure to. If not given,
                      the figure is shown instead.
    dpi:              resolution passed on when saving.
    scale:            factor applied to the figure size before saving.

    If plot_groups is empty a warning is logged and nothing is drawn.
    """
    if not plot_groups:
        logging.warning('no time series to plot')
        return

    if colour_overrides is None:
        colour_overrides = []

    nplots = len(plot_groups)
    plot_groups = sorted(plot_groups, key=lambda g: g.group_label())

    # use same global time scale for all plots
    trange = reduce(range_join, [g.trange() for g in plot_groups])

    # use group names for titles?
    group_titles = any(g.group_label() for g in plot_groups)

    figure = P.figure()
    for i, group in enumerate(plot_groups):
        plot = figure.add_subplot(nplots, 1, i+1)

        title = group.group_label() if group_titles else group.name()
        plot.set_title(title)

        # y-axis label: use timeseries label and units if only
        # one series in group, otherwise use a legend with labels,
        # and units alone on the axes. At most two different unit
        # axes can be drawn.

        def ylabel(unit):
            if len(group.series) == 1:
                lab = group.series[0].label()
                if unit:
                    lab += ' (' + unit + ')'
            else:
                lab = unit

            return lab

        uniq_units = sorted(set(s.units() for s in group.series))
        if len(uniq_units) > 2:
            logging.warning('more than two different units on the same plot')
            uniq_units = uniq_units[:2]

        # store each series in a slot corresponding to one of the units,
        # together with a best-effort label

        series_by_unit = [[] for _ in uniq_units]
        unique_labels = group.unique_labels(formatter=round_numeric_)

        for s, label in zip(group.series, unique_labels):
            try:
                series_by_unit[uniq_units.index(s.units())].append((s, label))
            except ValueError:
                # the unit of this series was dropped by the limit above
                pass

        # TODO: need to find a scheme of colour/line allocation that is
        # double y-axis AND greyscale friendly.

        palette = \
            [make_palette(cm, n, 0, 0.5) for
                cm, n in zip(['hot', 'winter'], [len(x) for x in series_by_unit])]

        lines = cycle(["-", (0, (3, 1))])

        for ui, unit in enumerate(uniq_units):
            first_plot = ui == 0
            if not first_plot:
                # second unit: draw against a right-hand y-axis
                plot = plot.twinx()

            axis_color = palette[ui][0]
            plot.set_ylabel(ylabel(unit), color=axis_color)
            for tick_label in plot.get_yticklabels():
                tick_label.set_color(axis_color)

            plot.get_yaxis().get_major_formatter().set_useOffset(False)
            plot.get_yaxis().set_major_locator(M.ticker.MaxNLocator(nbins=6))

            plot.set_xlim(trange)

            colours = cycle(palette[ui])
            line = next(lines)
            for s, label in series_by_unit[ui]:
                c = next(colours)
                for colour, tests in colour_overrides:
                    if all(run_select(t, s.meta) for t in tests):
                        c = colour

                plot.plot(s.t, s.y, color=c, ls=line, label=label)
                # treat excluded points specially: mark them with crosses,
                # clipped to the value range of the series itself
                ex = s.excluded()
                if ex is not None:
                    ymin, ymax = s.yrange()
                    plot.plot(ex.t, np.clip(ex.y, ymin, ymax), marker='x', ls='', color=c)

            if first_plot:
                plot.legend(loc=2, fontsize='small')
                plot.grid()
            else:
                plot.legend(loc=1, fontsize='small')

    # adapted from http://stackoverflow.com/questions/6963035
    axis_ymin = min([ax.get_position().ymin for ax in figure.axes])
    figure.text(0.5, axis_ymin - float(3)/figure.dpi, axis, ha='center', va='center')
    if save:
        if scale:
            base = figure.get_size_inches()
            figure.set_size_inches((base[0]*scale, base[1]*scale))

        figure.savefig(save, dpi=dpi)
    else:
        P.show()

def main():
    """Read the time series named on the command line, then list or plot them."""
    args = parse_clargs()
    axis = args.axis if args.axis else 'time'

    tss = []
    for filename in args.inputs:
        with open(filename, encoding='utf-8') as f:
            j = json.load(f)
        tss.extend(read_json_timeseries(j, axis, args.select))

    if args.list:
        for ts in tss:
            # The accessors fall back to an empty string for series that
            # carry no 'name' (or 'label') metadata.
            print('name:', ts.name())
            print('label:', ts.label())
            for k in sorted(ts.meta.keys()):
                if k not in ('name', 'label'):
                    print(k+':', ts.meta[k])
            print()
        return

    if args.trange:
        for ts in tss:
            ts.trestrict(args.trange)

    if args.exclude:
        for ts in tss:
            ts.exclude_outliers(args.exclude)

    groupby = args.groupby if args.groupby else []
    plots = gather_ts_plots(tss, groupby)

    if not args.outfile:
        M.interactive(False)

    colours = args.colours if args.colours else []
    plot_plots(plots, axis=axis, colour_overrides=colours, save=args.outfile, dpi=args.dpi, scale=args.scale)


if __name__ == '__main__':
    main()