add ~sjm1 to version string
[debian/digraphtools.git] / digraphtools / predicate.py
1 #! /usr/bin/env python
2
3 '''predicates that can be chained together with boolean expressions before evaluation
4
5 e.g.
6         a = predicate(lambda s: 'a' in s)
7         b = predicate(lambda s: 'b' in s)
8         c = predicate(lambda s: 'c' in s)
9
10         anyof = a | b | c
11         allof = a & b & c
12         not_anyof = anyof != True
13         assert anyof('--a--')
14         assert allof('-abc-')
15         assert not_anyof('12345')
16
17 Also, generate predicates such as above from strings
18
19         pf = PredicateContainsFactory()
20         anyof2 = pf.predicate_from_string('a | b | c')
21         pallof2 = pf.predicate_from_string('a & b & c')
22         not_anyof2 = pf.predicate_from_string('!(a & b & c)')
23         assert anyof2('--a--')
24         assert allof2('-abc-')
25         assert not_anyof2('12345')
26
27 These can be very useful for filtering of dependency graphs
28 '''
29
30 import operator
31 import re
32
33 def defer(origfunc,*argfs,**argfd):
34         '''defer execution of the arguments of a function
35         given origfunc return a function such that the code
36                 f = defer(origfunc, arga, key=argb)
37                 f(*newargs, **newargd)
38         is equivalent to:
39                 origfunc(arga(*newargs,**newargd), key=argb(*newargs,**newargd))
40         '''
41         def wrapper(*args, **argd):
42                 newargs = [argf(*args, **argd) for argf in argfs]
43                 newargd = dict((k,argf(*args, **argd)) for k,argf in argfd.items())
44                 return origfunc(*newargs, **newargd)
45         wrapper.origfunc = origfunc
46         wrapper.argfs = argfs
47         wrapper.argfd = argfd
48         wrapper.__repr__ = lambda s: "defer <%s>( *(%s) **(%s)) " % (repr(origfunc),repr(argfs),repr(argfd))
49         return wrapper
50
51 def always(val):
52         '''returns a function that always returns val regardless of inputs'''
53         def alwaysf(*args, **argd): return val
54         alwaysf.val = val
55         return alwaysf
56
57 class predicate(object):
58         '''chainable predicates
59         e.g.
60                 a = predicate(lambda s: 'a' in s)
61                 b = predicate(lambda s: 'b' in s)
62                 c = predicate(lambda s: 'c' in s)
63
64                 anyof = a | b | c
65                 allof = a & b & c
66                 not_anyof = anyof != True
67                 assert anyof('--a--')
68                 assert allof('-abc-')
69                 assert not_anyof('12345')
70         '''
71
72         def __init__(self, func):
73                 self.func = func
74         def __call__(self, arg):
75                 return self.func(arg)
76         def __and__(self,other):
77                 return self.__defer_infix__(other,operator.__and__)
78         def __or__(self,other):
79                 return self.__defer_infix__(other,operator.__or__)
80         def __ne__(self,other):
81                 return self.__defer_infix__(other,operator.__ne__)
82         def __defer_infix__(self,other,op):
83                 if isinstance(other, bool): 
84                         other = always(other)
85                 elif not isinstance(other, predicate): 
86                         return NotImplemented
87                 return self.__class__(defer(op, self, other))
88         def __repr__(self):
89                 return 'pred( '+repr(self.func)+' )'
90
91 class notp(predicate):
92         '''exactly the same as a predicate but inverts it's __call__ output'''
93         def __call__(self, *args, **argd):
94                 return not predicate.__call__(self, *args, **argd)
95
96
97
98 def partition_list(items, partition):
99         '''works like str.partition but for lists
100         e.g. partition(['aa','bb','cd','ee'],'cd') == ['aa','bb'],'cd',['ee']
101              partition(['aa','bb','cd','ee'],'ff') == ['aa','bb','cd','ee'],None,[]
102         '''
103         for i,obj in enumerate(items):
104                 if obj == partition:
105                         return items[:i],obj,items[i+1:]
106         return items,None,[]
107
108 class ParseSyntaxError(Exception): pass
109 class LexParse(object):
110         '''very simple lexer/parser'''
111         class _leaf(object):
112                 def __init__(self, data): self.data = data
113                 def __repr__(self): return '_leaf(%s)' %(repr(self.data))
114
115         valid_tokens = ['(',')','!','&','|']
116
117         def _match_bracket(self, tokens, i, bopen='(',bclose=')'):
118                 '''find the closing bracket that matches an open bracket
119                 return None if there is no matching bracket
120                 otherwise the index into tokens of the close bracket that matches the opening bracket at position i
121                 '''
122                 assert i < len(tokens)
123                 assert tokens[i] == bopen
124                 depth = 0
125                 for i in xrange(i,len(tokens)):
126                         tok = tokens[i]
127                         if tok == bopen: 
128                                 depth += 1
129                         elif tok == bclose:
130                                 depth -= 1
131                                 if depth < 0: return None
132                                 if depth == 0: return i
133                 return None
134                 
135         def lex(self, s):
136                 '''returns a list of tokens from a string
137                 tokens returned are anything inside self.valid_tokens or
138                 any other string not containing tokens, stripped
139                 of leading and trailing whitespace
140                 '''
141                 s = s.strip()
142                 if s == '': return []
143                 for tok in self.valid_tokens:
144                         l,t,r = s.partition(tok)
145                         if t==tok: return self.lex(l)+[tok]+self.lex(r)
146                 return [self._leaf(s)]
147
148         def parse(self, tokens):
149                 '''parse a list of tokens in order of predicence and return the output'''
150                 if len(tokens) == 0:
151                         raise ParseSyntaxError('Cannot parse empty subexpression')
152                 # Brackets
153                 l,part,r = partition_list(tokens, '(')
154                 if part != None:
155                         if ')' in l: raise ParseSyntaxError('unmatched ) near',tokens)
156                         r.insert(0,'(')
157                         rindex = self._match_bracket(r, 0)
158                         if rindex is None: raise ParseSyntaxError('unmatched ( near',tokens)
159                         assert r[rindex] == ')'
160                         inner = r[1:rindex]
161                         r = r[rindex+1:]
162                         inner = self.brackets(self.parse(inner))
163                         return self.parse(l+[inner]+r)
164
165                 # unary not
166                 if tokens[0] == '!':
167                         if len(tokens) < 2: raise ParseSyntaxError('syntax error near',tokens)
168                         # this only works without other unary operators
169                         if tokens[1] in self.valid_tokens: raise ParseSyntaxError('syntax error near', tokens)
170                         argument = self.parse([ tokens[1] ])
171                         inv = self.notx(argument)
172                         return self.parse([inv]+tokens[2:])
173
174                 # and
175                 l,part,r = partition_list(tokens, '&')
176                 if part != None:
177                         if not len(l) or not len(r):
178                                 raise ParseSyntaxError('syntax error near', tokens)
179                         l,r = self.parse(l), self.parse(r)
180                         return self.andx(l,r)
181
182                 # or
183                 l,part,r = partition_list(tokens, '|')
184                 if part != None:
185                         if not len(l) or not len(r):
186                                 raise ParseSyntaxError('syntax error near', tokens)
187                         l,r = self.parse(l), self.parse(r)
188                         return self.orx(l,r)
189
190                 if len(tokens) == 1:
191                         if isinstance(tokens[0], self._leaf):
192                                 return self.data(tokens[0].data) # base case
193                         elif tokens[0] in self.valid_tokens:
194                                 raise ParseSyntaxError('syntax error near',tokens)
195                         return tokens[0] # Already parsed
196
197                 # Nothing else is sane
198                 print repr(tokens)
199                 raise ParseSyntaxError('syntax error near', tokens)
200
201         def brackets(self, expr): 
202                 '''You almost never want to override this'''
203                 return expr 
204
205         def notx(self, expr): pass
206         def andx(self, expr_l, expr_r): pass
207         def orx(self, expr_l, expr_r): pass
208         def data(self, data): pass
209
210 class BoolParse(LexParse):
211         '''example parser implementation
212         bp = BoolParse()
213         assert False or (False and not (True or False)) == False
214         inp = 'False | (False & ! (True | False))'
215         assert bp.parse(bp.lex(inp)) is False
216         '''
217         notx = lambda s,expr: not expr
218         andx = lambda s,l,r: l and r
219         orx = lambda s,l,r: l or r
220         def data(self,data):
221                 return not(data.lower() == 'false' or data == '0')
222
223
224 class PredicateContainsFactory(LexParse):
225         '''create predicates that act on the contents of a container passed to them'''
226         def predicate_from_string(self, definition):
227                 tokens = self.lex(definition)
228                 return self.parse(tokens)
229         def notx(self, pred):
230                 return notp(pred)
231         def andx(self, pred_l, pred_r):
232                 return pred_l & pred_r
233         def orx(self, pred_l, pred_r):
234                 return pred_l | pred_r
235         def data(self, data):
236                 return predicate(lambda container: data in container)
237
238 if __name__ == "__main__":
239         def defer_sample():
240                 def a(arga, moo=None, argb=None):
241                         return arga+argb
242                 def b(arga, moo=None, argb=None):
243                         return arga^argb
244                 def c(arga, moo=None, argb=None):
245                         return arga,argb
246                 ooer = defer(c, a,argb=b)
247                 result = ooer(1234,argb=4312)
248                 assert result == (5546, 5130)
249
250         def predicate_sample():
251                 a = predicate(lambda s: 'a' in s)
252                 b = predicate(lambda s: 'b' in s)
253                 c = predicate(lambda s: 'c' in s)
254                 d = predicate(lambda s: 'd' in s)
255
256                 anyof = a | b | c
257                 allof = a & b & c
258
259                 not_anyof = anyof != True
260                 not_allof = allof != True
261
262
263                 assert anyof('asdf')
264                 assert allof('abc')
265                 assert not anyof('1234')
266                 assert not allof('ab')
267                 assert not_anyof('1234')
268                 assert not_allof('1234')
269
270                 nottest = a & b & notp( c | d )
271                 assert nottest('ab')
272                 assert not nottest('abc')
273                 assert not nottest('b')
274                 assert not nottest('d')
275                 assert not nottest('bd')
276                 assert not nottest('abd')
277
278                 e = predicate(lambda n: n%2==0)
279                 t = predicate(lambda n: n%3==0)
280                 eset = set(filter(e, range(1000)))
281                 tset = set(filter(t, range(1000)))
282                 eutset = set(filter(e|t, range(1000)))
283                 eitset = set(filter(e&t, range(1000)))
284                 assert eutset == eset.union(tset)
285                 assert eitset == eset.intersection(tset)
286
287         def parser_internal_test():
288                 lp = LexParse()
289                 #lp._match_bracket(self, tokens, i, bopen='(',bclose=')'):
290                 assert lp._match_bracket('()',0) == 1
291                 assert lp._match_bracket('(',0) == None
292                 assert lp._match_bracket(')))))()))))',5) == 6
293                 assert lp._match_bracket('((()(()))(()((()(()()))())()))',0) == 29
294                 assert lp._match_bracket('((()(()))(()((()(()()))())()))',2) == 3 
295                 assert lp._match_bracket('((()(()))(()((()(()()))())()))',12) == 25
296                 assert lp._match_bracket('((()(()))(()((()(()()))())()))',23) == 24
297                 assert lp._match_bracket('((()(()))(()((()(()()))())()))',23) == 24
298                 assert lp._match_bracket('((()(()))(()((()(()()))())()))',26) == 27
299                 assert lp._match_bracket('((()(()))(()((()(()()))())()))',9) == 28
300                 assert lp._match_bracket('((()(()))(()((()(()()))())()))',10) == 11
301                 assert lp._match_bracket('((()(()))(()((()(()()))())()))',16) == 21
302
303         def parser_sample():
304                 bp = BoolParse()
305                 assert False or (False and not (True or False)) == False
306                 inp = 'False | (False & ! (True | False))'
307                 assert bp.parse(bp.lex(inp)) is False
308                 assert bp.parse(bp.lex('true & !false'))
309
310         def predicate_factory_sample():
311                 pf = PredicateContainsFactory()
312                 pred = pf.predicate_from_string('fish & !cow')
313                 assert pred(['fish', 'bat', 'pidgeon'])
314                 assert not pred( ['fish', 'cow', 'bat'] )
315                 assert not pred( [] )
316                 assert not pred( ['cow'] )
317                 assert not pred( ['bat','pig'] )
318
319                 a = predicate(lambda s: 'a' in s)
320                 b = predicate(lambda s: 'b' in s)
321                 c = predicate(lambda s: 'c' in s)
322                 anyof2 = pf.predicate_from_string('a | b | c')
323                 allof2 = pf.predicate_from_string('a & b & c')
324                 not_anyof2 = pf.predicate_from_string('!(a & b & c)')
325                 assert anyof2('--a--')
326                 assert allof2('-abc-')
327                 assert not_anyof2('12345')
328
329                 pred = pf.predicate_from_string('( a | b | c ) & ( c | e | d )')
330                 assert not pred('b')
331                 assert pred('c')
332                 assert pred('cd')
333                 assert pred('acd')
334                 assert not pred('ab')
335                 assert not pred('a')
336
337         parser_internal_test()
338         defer_sample()
339         predicate_sample()
340         parser_sample()
341         predicate_factory_sample()