Pattern Matching in Python

@proofit404

OCaml

The Hard Way

                        
(* Bind the name x to the value 1; the toplevel infers and prints its type. *)
# let x = 1;;

# x;;
- : int = 1
                        
                    

Even harder

                        
(* A two-argument function: parameters and arguments are separated by
   spaces, with no parentheses or commas at the definition or call site. *)
# let add x y = x + y;;

# add 1 2;;
- : int = 3
                        
                    

How far can we go?

                        
(* rec is required for a function to refer to itself.
   The base case uses <= 1 so that factorial 0 (or a negative input)
   terminates instead of recursing until the stack overflows. *)
# let rec factorial x =
  if x <= 1
  then 1
  else x * factorial (x - 1);;

# factorial 5;;
- : int = 120
                        
                    

Algebraic Data Types

List

                        
(* A list is either empty (Nil) or an element consed onto another list.
   Constructors that carry data need of, and the element type is the
   parameter 'a so the same definition works for any element type. *)
# type 'a list = Cons of 'a * 'a list | Nil;;

# let mylist = Cons(1, Cons(2, Cons(3, Cons(4, Cons(5, Nil)))));;

(* The built-in list has the same shape: :: plays Cons and [] plays Nil. *)
# let mylist = 1 :: 2 :: 3 :: 4 :: 5 :: [];;

(* [a; b; c] is syntactic sugar for the :: chain above. *)
# let mylist = [1; 2; 3; 4; 5];;
                        
                    

Fold right

                        
(* foldr f initial [a; b; c] computes f a (f b (f c initial)):
   the match splits the list into its head and the tail folded recursively. *)
# let rec foldr f initial l =
    match l with
    | [] -> initial
    | head :: tail -> f head (foldr f initial tail);;

(* Wrapping an operator in parentheses turns it into a function. ( * ) needs
   the spaces: without them the parenthesis and star would open a comment. *)
# foldr ( + ) 0 [1; 2; 3; 4; 5];;
- : int = 15
# foldr ( * ) 1 [1; 2; 3; 4; 5];;
- : int = 120
                        
                    

Tree

                        
(* A binary tree: integers live only in the leaves, and every Node has
   exactly two subtrees. *)
# type tree = Leaf of int
            | Node of tree * tree;;

# let mytree = Node(Node(Leaf(0), Leaf(2)), Node(Leaf(3), Leaf(4)));;
                        
                    

Predicate search

                        
(* True if any leaf value satisfies test. || short-circuits, so the right
   subtree is searched only when the left one has no match. *)
# let rec have_in_tree test tree =
    match tree with
    | Leaf x -> test x
    | Node (left, right) ->
       have_in_tree test left
       || have_in_tree test right;;

# have_in_tree (fun x -> x = 0) mytree;;
- : bool = true
                        
                    

Let's do it in Python!

So what does it look like?

                        
from patterns import match

# Each @match clause registers one alternative under the shared name
# ``test``; a call runs the first clause whose pattern equals its arguments.
@match(1)
def test(x):
    return 'one'

@match(2)
def test(x):
    return 'two'

>>> test(1)
'one'
>>> test(2)
'two'
                        
                    

Match decorator

                        
def match(*expressions):
    """Decorator factory: register the decorated function as one clause.

    ``expressions`` are the patterns for the clause's positional arguments.
    Every clause of a multi-clause function is decorated separately, and all
    of them share one dispatcher bound to the function's name.
    """

    def decorator(f):
        # First clause with this name: create a fresh dispatcher.
        # Later clauses: reuse the dispatcher already bound to that name in
        # the defining module's globals (the previous clause's decorator
        # returned it, so the name now points at it).
        # NOTE(review): the lookup goes through f.__globals__ (module
        # globals), so clauses defined inside a function or class body are
        # not grouped this way, and an unrelated global with the same name
        # would be picked up as the dispatcher -- confirm this is intended.
        if f.__name__ not in f.__globals__:
            def wrapper(*args):
                ...
        else:
            wrapper = f.__globals__[f.__name__]
        ...
        # Return the dispatcher, not f, so the name stays bound to it.
        return wrapper

    return decorator
                        
                    

Store signatures

                        
# The dispatcher keeps its clauses as (signature, function) pairs in
# registration order.
def wrapper(*args):
    ...

wrapper.signatures = []

# The patterns passed to @match(...) become this clause's Signature.
signature = Signature(*expressions)

wrapper.signatures.append((signature, f))
                        
                    

Signature

                        
class Signature:
    """The patterns one @match clause was registered with."""

    def __init__(self, *args):
        # Patterns in positional order, exactly as passed to @match(...).
        self.args = args

    def __eq__(self, args):
        # The call's arguments stay on the left, so each element comparison
        # is ``arg == pattern``; argument types that don't know the pattern
        # fall back to the pattern's reflected __eq__.
        patterns = self.args
        return args == patterns

    def arguments(self, args):
        # Plain patterns hand the call's arguments through untouched.
        passthrough = args
        return passthrough
                        
                    

Analyze call

                        
def wrapper(*args):
    """Dispatch a call to the first registered clause whose signature matches.

    ``wrapper.signatures`` holds ``(signature, function)`` pairs in
    registration order; the first signature equal to ``args`` wins, and the
    clause receives the arguments that signature derives from the call.

    Raises:
        TypeError: if no clause matches, instead of silently returning None.
    """
    for signature, f in wrapper.signatures:
        if signature == args:
            return f(*signature.arguments(args))
    # Like OCaml's Match_failure: a call that no pattern covers is an error.
    raise TypeError('no pattern matches arguments %r' % (args,))
                        
                    

Introduce expressions

Expressions

                        
from patterns import match, x, _

@match(x > 5)
def greater_test(x):
    return '%s greater then five' % x

@match(_)
def greater_test(x):
    return '%s less then or equal to five' % x

>>> greater_test(7)
'7 greater then five'
>>> greater_test(1)
'1 less then or equal to five'
                        
                    

Import hacks

                        
class module(types.ModuleType):
    """Replacement for the ``patterns`` module that invents pattern variables.

    ``__getattr__`` only runs for names not found normally, so
    ``from patterns import x, _`` hands back ``Expression('x')`` and
    ``Expression('_')`` even though neither is defined anywhere.
    """

    def __getattr__(self, name):

        if name == 'match':
            return match
        # Dunder names (``__path__``, ``__all__``, ``__file__``, ...) are
        # probes from the import system and introspection tools, not pattern
        # variables. Answering them with an Expression makes ``hasattr`` lie
        # (e.g. the module appears to be a package) and breaks
        # ``from patterns import *``.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return Expression(name)

# NOTE(review): presumably kept so the original module object stays alive
# after it is replaced in sys.modules (older Pythons cleared a collected
# module's globals, where ``match`` and ``Expression`` live) -- confirm.
old_module = sys.modules['patterns']

# Swap in the magic module; the import machinery returns whatever
# sys.modules['patterns'] holds once this module finishes executing.
new_module = sys.modules['patterns'] = module('patterns')
                        
                    

Expression

                        
class Expression:
    """A pattern variable such as ``x`` handed out by ``from patterns import x``.

    Comparison operators on it build predicates rather than booleans.
    """

    def __init__(self, name):
        # The name it was imported under; forwarded to every predicate.
        self.name = name

    def __gt__(self, other):
        # ``x > 5`` yields a predicate remembering the variable and the bound.
        predicate = GreaterPredicate(self.name, other)
        return predicate
                        
                    

Predicate

                        
class Predicate:
    """A pattern test built from one named pattern variable.

    ``name`` is the variable the predicate came from (``'x'`` for ``x > 5``);
    ``args`` holds the operands captured from the expression.
    """

    def __init__(self, name, *args):

        self.name = name
        self.args = args

class GreaterPredicate(Predicate):
    """Pattern built by ``x > bound``; matches values strictly above ``bound``."""

    def __eq__(self, other):
        # Signatures are matched with ``arg == pattern``; built-in argument
        # types return NotImplemented for an unknown operand, so Python falls
        # back to this reflected __eq__.
        bound = self.args[0]
        # Compare directly so mixed numeric types work (7.5 matches
        # ``x > 5``, 7 matches ``x > 5.0``); values that cannot be ordered
        # against the bound simply do not match.
        try:
            return other > bound
        except TypeError:
            return False
                        
                    

List pattern matching

Fold right

                        
from patterns import match, l, _

# Base case: an empty list yields the initial accumulator.
@match(_, _, [])
def foldr(f, initial, seq):
    return initial

# Recursive case: the adjacent pair l[0], l[1:] destructures ONE list
# argument into head and tail, so this clause takes four parameters for a
# three-argument call.
@match(_, _, l[0], l[1:])
def foldr(f, initial, head, tail):
    return f(head, foldr(f, initial, tail))

>>> foldr(lambda x, y: x + y, 0, [1, 2, 3, 4, 5])
15
>>> foldr(lambda x, y: x * y, 1, [1, 2, 3, 4, 5])
120
                        
                    

Expression

                        
class Expression:

    # ``__getitem__`` accepts every int and never raises IndexError, so the
    # legacy sequence-iteration fallback (``for``, ``in``, ``list(l)``, ``*l``)
    # would loop forever; declaring ``__iter__`` as None makes those
    # operations raise TypeError instead.
    __iter__ = None

    def __getitem__(self, item):
        """Build a list-destructuring pattern from ``l[...]``.

        ``l[i]`` (an int) yields an IndexPredicate for one element;
        ``l[a:b]`` yields a SlicePredicate for a sub-sequence.

        Raises:
            TypeError: for any other subscript, instead of silently
                producing ``None`` as the pattern.
        """
        if isinstance(item, int):
            return IndexPredicate(self.name, item)
        if isinstance(item, slice):
            return SlicePredicate(self.name, item)
        raise TypeError(
            'pattern indices must be int or slice, not %s' % type(item).__name__
        )
                        
                    

Predicates

                        
class IndexPredicate(Predicate):
    """Pattern from ``l[i]``: the element at integer index ``i``."""


class SlicePredicate(Predicate):
    """Pattern from ``l[a:b]``: a slice of the same sequence."""


class HeadTailPredicate(Predicate):
    """Pattern replacing an adjacent index/slice pair such as ``l[0], l[1:]``.

    ``args[0]`` is the index; the pattern matches any list long enough to
    have an element at that index.
    """

    def __eq__(self, other):
        if not isinstance(other, list):
            return False
        index = self.args[0]
        return len(other) >= index + 1
                        
                    

Signature equality

                        
class Signature:

    def __eq__(self, args):

       # Match the call's arguments against this clause's patterns, after
       # rewriting an adjacent index/slice pair on one variable (``l[0], l[1:]``)
       # into a single HeadTailPredicate, since the pair stands for ONE list
       # argument.
       signature = ()
       # NOTE(review): groupby() is called without a key, so it groups runs of
       # *equal* elements and ``key`` is the element itself. With no __eq__ on
       # IndexPredicate/SlicePredicate, the pair test below seemingly could not
       # see both in one group, and HeadTailPredicate would receive a predicate
       # rather than a name. Grouping by ``name`` looks intended -- confirm.
       for key, predicates in itertools.groupby(self.args):
           # NOTE(review): groupby yields lazy group iterators; len() and
           # indexing below assume the elided code materialized the group
           # (e.g. ``predicates = list(predicates)``).
           ...
           if (len(predicates) == 2 and
               isinstance(predicates[0], IndexPredicate) and
               isinstance(predicates[1], SlicePredicate)):
                   # Only the element index is kept; the slice's own bounds
                   # are not consulted.
                   item = predicates[0].args[0]
                   signature += (HeadTailPredicate(key, item),)
           ...
       return args == signature
                        
                    

Signature call

                        
class Signature:

    def arguments(self, args):
        """Build the positional arguments passed to the matched clause.

        A list argument matched by a HeadTailPredicate is split in two: the
        element at the predicate's index and everything after it.
        """

        call = ()
        for key, predicates in itertools.groupby(self.args):
            # NOTE(review): next() requires an iterator; if ``args`` is the
            # tuple the dispatcher received as *args, it must be wrapped with
            # iter() before this loop -- confirm against the full code.
            arg = next(args)
            ...
            # ``predicate`` is bound in the elided code above (presumably
            # taken from ``predicates``).
            if isinstance(predicate, HeadTailPredicate):
                item = predicate.args[0]
                # Head and tail become two separate arguments; the tail always
                # starts right after ``item``.
                call += (arg[item], arg[item + 1:])
            ...
        return call
                        
                    

Don't Try This at Home!

Thanks!