"""Helpers and utilities for use by codes_configurator and friends."""

#
# Copyright (C) 2013 University of Chicago.
# See COPYRIGHT notice in top-level directory.
#
#

import re
import os
import imp
from collections import Sequence

class configurator:
    """Iterate over every configuration described by a substitution module.

    ``module`` must define ``cfields``: a sequence of ``(label, values)``
    entries.  Iteration walks the cartesian product of all ``values``
    sequences, with the first field varying fastest.  The module may also
    define ``excepts`` (a sequence of dicts; a configuration matching every
    key/value pair of any one dict is skipped) and the pair
    ``cfields_derived_labels`` / ``cfields_derived`` (a callable that adds the
    derived entries to the current replacement map).

    ``replace_pairs`` is a flat sequence ``[token, value, token, value, ...]``
    of fixed substitutions; each value is converted to int/float when
    possible (via try_num), otherwise kept as given.

    Each call to ``next()`` returns None; the current configuration lives in
    ``self.replace_map`` (token -> value) and can be logged with
    ``write_header`` / ``write_config_line``.
    """

    # note: object assumes input has been validated by caller, though it will
    #       still throw if it finds duplicate keys in the dict
    def __init__(self, module, replace_pairs):
        # checks - check cfields and friends (raises TypeError)
        check_cfields(module)
        # replace_pairs is a flat token/value list, so its length must be even
        if len(replace_pairs) % 2 != 0:
            raise ValueError("token pairs must come in twos")

        self.mod = module
        self.num_fields = len(self.mod.cfields)
        # logged column order: iterated fields, then fixed tokens (derived
        # labels are appended below); list() lets a tuple be passed as well
        self.labels = [k[0] for k in self.mod.cfields] + \
            list(replace_pairs[0::2])
        self.replace_map = {k[0]: None for k in self.mod.cfields}
        self.start_iter = False
        self.in_iter = False

        self.has_except = "excepts" in self.mod.__dict__
        self.has_derived = "cfields_derived_labels" in self.mod.__dict__ and \
                           "cfields_derived" in self.mod.__dict__

        for i in range(0, len(replace_pairs), 2):
            k, vstr = replace_pairs[i], replace_pairs[i + 1]
            # fixed tokens may not collide with iterated (or earlier) tokens
            if k in self.replace_map:
                raise ValueError("token " + k + " of token_pairs matches a token "
                        "given by the substitution py file")
            # try making v numeric - fall back to the value as given
            v = try_num(vstr)
            if v is None:
                v = vstr
            self.replace_map[k] = v

        # derived labels go last; their values are filled in by
        # cfields_derived on every iteration
        if self.has_derived:
            self.labels += list(self.mod.cfields_derived_labels)

    def __iter__(self):
        # (re)start iteration from the first configuration
        self.start_iter = True
        self.in_iter = True
        return self

    def next(self):
        """Advance to the next valid configuration and return None.

        The configuration itself is left in ``self.replace_map``.  Derived
        fields are computed before ``excepts`` are checked, so exceptions
        may match on derived values.  Excepted configurations are skipped in
        a loop rather than by recursion, so long runs of them cannot hit the
        recursion limit.

        Raises StopIteration once every configuration has been produced, and
        keeps raising it on later calls until ``__iter__`` is called again.
        """
        while True:
            self._advance()
            # add derived fields before exceptions
            if self.has_derived:
                self.mod.cfields_derived(self.replace_map)
            # keep going only if this config is excepted
            if not (self.has_except and
                    is_replace_except(self.mod.excepts, self.replace_map)):
                return None

    # Python 3 spelling of the iterator protocol
    __next__ = next

    def _advance(self):
        """Step the field iterators to the next raw configuration (odometer)."""
        if not self.in_iter:
            # either uninitialized or at the end of the road
            raise StopIteration
        if self.start_iter:
            # first iteration - initialize the iterators; an empty value
            # sequence means there is no configuration at all
            self.start_iter = False
            self.iterables = [iter(k[1]) for k in self.mod.cfields]
            try:
                for i in range(self.num_fields):
                    # builtin next(), not this class's method
                    self.replace_map[self.labels[i]] = next(self.iterables[i])
            except StopIteration:
                self.in_iter = False
                raise
            return
        # > first iteration: bump field i; when it is exhausted, reset it to
        # its first value and carry into field i + 1
        for i in range(self.num_fields):
            try:
                self.replace_map[self.labels[i]] = next(self.iterables[i])
                return
            except StopIteration:
                self.iterables[i] = iter(self.mod.cfields[i][1])
                self.replace_map[self.labels[i]] = next(self.iterables[i])
        # last iterable has wrapped: the full set has been generated, and the
        # iterator must stay exhausted
        self.in_iter = False
        raise StopIteration

    def write_header(self, fout):
        """Write a '# format:' comment naming each column of the config log."""
        fout.write("# format:\n# <config index>")
        for l in self.labels:
            fout.write(" <" + l + ">")
        fout.write("\n")

    def write_config_line(self, ident, fout):
        """Log one configuration: ``ident`` then each label's current value.

        Fields are space-separated and the line ends with a newline.
        """
        fields = [str(ident)] + [str(self.replace_map[l]) for l in self.labels]
        fout.write(" ".join(fields) + "\n")

def is_replace_except(except_map, replace_map):
    """Report whether ``replace_map`` is excluded by any entry of ``except_map``.

    ``except_map`` is iterated as a collection of dicts.  An entry matches
    when every one of its keys maps to an equal value in ``replace_map``, so
    an empty dict matches anything.  A key absent from ``replace_map`` raises
    KeyError.

    Returns True on the first matching entry, False if none match.
    """
    for exception in except_map:
        differs = any(exception[key] != replace_map[key] for key in exception)
        if not differs:
            return True
    return False

# checks - make sure cfields is set and is the correct type
def check_cfields(module):
    """Validate the configuration attributes of a substitution module.

    Required: ``cfields`` - a non-empty sequence whose entries are sequences
    ``(label, values, ...)`` with a str label and a sequence of values.
    Optional: ``excepts`` - a sequence of dicts (an empty one is fine).
    Optional, but only as a pair: ``cfields_derived_labels`` - a sequence of
    str (may be empty) - and ``cfields_derived`` - a callable.

    Every entry is checked, not just the first one, so a malformed later
    entry is reported here instead of failing obscurely during iteration.

    Raises:
        TypeError: if an attribute is missing or malformed.
    """
    attrs = module.__dict__
    if "cfields" not in attrs:
        raise TypeError("Expected cfields to be defined in " + str(module))

    def _is_field(entry):
        # values (entry[1]) are re-iterated on every wrap-around by the
        # configurator, so they must be a real sequence, not a one-shot iterator
        return (isinstance(entry, Sequence) and len(entry) >= 2 and
                isinstance(entry[0], str) and isinstance(entry[1], Sequence))

    cfields = module.cfields
    if not (isinstance(cfields, Sequence) and len(cfields) > 0 and
            all(_is_field(entry) for entry in cfields)):
        raise TypeError("cfields in incorrect format, see usage")

    if "excepts" in attrs and not \
            (isinstance(module.excepts, Sequence) and
             all(isinstance(e, dict) for e in module.excepts)):
        raise TypeError("excepts not in correct format, see usage")

    dl = "cfields_derived_labels" in attrs
    d = "cfields_derived" in attrs
    if dl != d:
        # exactly one of the pair is defined
        raise TypeError("both cfields_derived_labels and cfields_derived must "
                "be set")
    elif dl and not \
            (isinstance(module.cfields_derived_labels, Sequence) and
             all(isinstance(l, str) for l in module.cfields_derived_labels) and
             callable(module.cfields_derived)):
        raise TypeError("cfields_derived_labels must be a sequence of "
                "strings, cfields_derived must be callable (accepting a "
                "dict of replace_token, replacement pairs and adding pairs "
                "for each label in cfields_derived_labels")

# import a python file (assumes there is a .py suffix!!!)
def import_from(filename):
    """Load the Python source file ``filename`` as a module and return it.

    The module is named after the file's base name with its extension
    stripped, and only the file's own directory is searched.  Because the
    extension is discarded and the bare name is resolved by
    ``imp.find_module``, ``filename`` is expected to carry a ``.py`` suffix.

    The file handle opened by ``imp.find_module`` is always closed, even
    when loading the module raises.
    """
    path, name = os.path.split(filename)
    name = os.path.splitext(name)[0]

    fmod, fname, data = imp.find_module(name, [path])
    try:
        return imp.load_module(name, fmod, fname, data)
    finally:
        # imp leaves closing the returned file object to the caller
        if fmod is not None:
            fmod.close()

# attempts casting to a numeric type, returning None on failure
def try_num(s):
    """Convert ``s`` to an int, else to a float; return None if neither works.

    int is tried first, so "3" gives 3 while "2.5" or "1e3" fall through to
    float.  Both ValueError (unparsable text such as "abc") and TypeError
    (inputs the constructors reject outright, such as None or a list) count
    as failure, so None is returned rather than an exception escaping.
    """
    for cast in (int, float):
        try:
            return cast(s)
        except (ValueError, TypeError):
            pass
    return None

# given a string and a set of substitution pairs,
#   perform a single pass replace
# st is the string to perform substitution on
# kv_pairs is a dict or a sequence of sequences.
# kv_pairs examples:
# { a:b, c:d }
# [[a,b],[c,d]]
# [(a,b),(c,d)]
# ((a,b),(c,d))
def replace_many(st, kv_pairs):
    """Replace every occurrence of each key in ``st`` in one left-to-right pass.

    Replacement values are converted with ``str()``.  Because it is a single
    pass, replaced text is never re-scanned (swapping "a"<->"b" works).
    Longer keys take precedence over keys that are their prefixes, so a
    token like ``N`` cannot clobber the start of ``N_PER_NODE``.

    Args:
        st: the string to substitute into.
        kv_pairs: a dict ``{key: value}`` or a sequence of ``(key, value)``
            sequences; for repeated keys in a sequence the last one wins.

    Returns:
        The substituted string (``st`` unchanged when there are no pairs).

    Raises:
        TypeError: if ``kv_pairs`` is neither form.
    """
    # dict-ify and string-ify the input
    if isinstance(kv_pairs, dict):
        rep_dict = {k: str(kv_pairs[k]) for k in kv_pairs}
    elif isinstance(kv_pairs, Sequence) and \
            all(isinstance(pair, Sequence) for pair in kv_pairs):
        # the replacement is the pair's own second item
        rep_dict = {pair[0]: str(pair[1]) for pair in kv_pairs}
    else:
        raise TypeError("Expected dict or sequence of sequences types")

    if not rep_dict:
        # an empty alternation would match "" everywhere; nothing to do
        return st

    # regex alternation takes the first alternative that matches, so try
    # longer keys first
    keys = sorted(rep_dict, key=len, reverse=True)
    pat = re.compile("|".join(re.escape(k) for k in keys))

    return pat.sub(lambda match: rep_dict[match.group(0)], st)