forked from OSchip/llvm-project
				
			
		
			
				
	
	
		
			279 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			279 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
"""Utilities for enumeration of finite and countably infinite sets.
"""

from __future__ import absolute_import, division, print_function

###
# Countable iteration

# Simplifies some calculations: Aleph0 stands in for an unbounded
# (countably infinite) size.
class Aleph0(int):
    """The cardinality of a countably infinite set.

    ``aleph0`` (the single instance) is used as the "unbounded" value for
    the size/bound parameters of the enumeration routines in this module.
    It compares greater than every finite integer, absorbs addition and
    multiplication by non-zero values, and refuses subtraction.

    Because this subclasses ``int``, the underlying integer value is 0, so
    every operation the routines rely on must be overridden explicitly.
    In particular Python 3 ignores ``__cmp__``; without the rich comparison
    methods below, aleph0 would compare (and sort) as if it were 0.
    """

    _singleton = None

    def __new__(cls):
        # Always return the same instance so callers can test ``x is aleph0``.
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self): return '<aleph0>'

    def __str__(self): return 'inf'

    # --- Ordering: larger than every finite integer, equal only to itself.
    # Because Aleph0 subclasses int, Python consults these (reflected)
    # methods first even when aleph0 is the right-hand operand of an int.
    def __eq__(self, b):
        return b is self

    def __ne__(self, b):
        return b is not self

    def __lt__(self, b):
        return False

    def __le__(self, b):
        return b is self

    def __gt__(self, b):
        return b is not self

    def __ge__(self, b):
        return True

    # Defining __eq__ would otherwise make instances unhashable on Python 3.
    __hash__ = int.__hash__

    def __cmp__(self, b):
        # Python 2 fallback (only used by cmp()); consistent with the above.
        return 0 if b is self else 1

    # The underlying int is 0, but an infinite size should be truthy.
    def __bool__(self):
        return True

    __nonzero__ = __bool__

    # --- Arithmetic

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")

    __rsub__ = __sub__

    def __add__(self, b):
        return self

    __radd__ = __add__

    def __mul__(self, b):
        # 0 * aleph0 is defined as 0 here.
        if b == 0: return b
        return self

    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0: raise ZeroDivisionError
        return self

    # NOTE(review): the reflected divisions mirror the forward operator, so
    # finite / aleph0 also yields aleph0 (not 0); kept as originally aliased.
    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    __rtruediv__ = __floordiv__
    __div__ = __floordiv__
    __rdiv__ = __floordiv__

    def __pow__(self, b):
        if b == 0: return 1
        return self

aleph0 = Aleph0()

def base(line):
    """Return how many pairs lie on the diagonals before diagonal *line*.

    Pairs (x, y) with x + y == line are enumerated after every pair on an
    earlier diagonal, so this triangular number is also the index of the
    first pair on diagonal *line*.
    """
    following = line + 1
    return (line * following) // 2

def pairToN(pair):
    """Return the index N such that getNthPair(N) == pair (the inverse map)."""
    x, y = pair
    # (x, y) sits on diagonal x + y, at offset y along that diagonal.
    return base(x + y) + y

def getNthPairInfo(N):
    """Return (line, index) locating the N-th pair in diagonal order.

    For N >= 0, *line* is the largest value with base(line) <= N and
    *index* is N - base(line), the position along that diagonal.
    """
    # Handled directly so the search below can start from line 1.
    if N == 0:
        return (0, 0)

    # Exponential (galloping) search: double the upper bound until it
    # overshoots N.
    lo, hi = 1, 2
    while base(hi) <= N:
        lo, hi = hi, hi << 1

    # Bisect while keeping base(lo) <= N < base(hi).
    while hi - lo > 1:
        mid = (lo + hi) >> 1
        if base(mid) > N:
            hi = mid
        else:
            lo = mid

    return lo, N - base(lo)

def getNthPair(N):
    """Return the N-th pair (x, y) of naturals in diagonal order.

    Diagonals x + y = 0, 1, 2, ... are visited in turn; within a diagonal,
    y increases (and x decreases) as N increases.
    """
    line, index = getNthPairInfo(N)
    x = line - index
    return (x, index)

def getNthPairBounded(N, W=aleph0, H=aleph0, useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    Either bound may be aleph0; when both are, this is getNthPair(N). By
    default pairs are produced by sliding a diagonal across the W x H
    rectangle. With useDivmod=True the coordinate along the shorter side
    (x when W == H) varies fastest instead.

    Raises ValueError if a bound is not positive or N is outside [0, W*H).
    """
    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    elif N < 0 or N >= W * H:
        # Negative N used to slip through and yield out-of-range pairs.
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W <= H, transposing the result back.
    if H < W:
        x, y = getNthPairBounded(N, H, W, useDivmod=useDivmod)
        return y, x

    if useDivmod:
        return N % W, N // W
    else:
        # Conceptually we want to slide a diagonal line across a
        # rectangle. This gives more interesting results for large
        # bounds than using divmod.

        # Lower-left triangle (x + y < W): the ordinary diagonal order.
        cornerSize = base(W)
        if N < cornerSize:
            return getNthPair(N)

        # Upper-right triangle of a finite rectangle: count backwards from
        # the far corner and reflect the pair.
        if H is not aleph0:
            M = W * H - N - 1
            if M < cornerSize:
                x, y = getNthPair(M)
                return (W - 1 - x, H - 1 - y)

        # Middle band: every diagonal here holds exactly W points, so
        # compute line and index from the number of times we wrap.
        N = N - cornerSize
        index, offset = N % W, N // W
        # p = (W-1, 1+offset) + (-1,1)*index
        return (W - 1 - index, 1 + offset + index)

def getNthPairBoundedChecked(N, W=aleph0, H=aleph0, useDivmod=False,
                             GNP=getNthPairBounded):
    """Debug wrapper around GNP (getNthPairBounded by default).

    Asserts that the returned pair lies inside the W x H rectangle before
    returning it as a tuple.
    """
    x, y = GNP(N, W, H, useDivmod)
    assert 0 <= x < W
    assert 0 <= y < H
    return x, y

def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple whose elements all satisfy 0 <= x_i < H.

    By default the tuple is split in half recursively, each half being one
    coordinate of a bounded pair. With useLeftToRight=True the elements
    are instead peeled off one at a time from the front.
    """
    if useLeftToRight:
        # NOTE(review): the remainder left after the last element is
        # discarded, so this mode does not appear injective (e.g. with
        # W=2, H=3, inputs N=0 and N=5 both seem to give (0, 0)) -- confirm
        # whether that is intended.
        elts = []
        for _ in range(W):
            head, N = getNthPairBounded(N, H)
            elts.append(head)
        return tuple(elts)

    if W == 0:
        return ()
    if W == 1:
        return (N,)
    if W == 2:
        return getNthPairBounded(N, H, H)

    # Split into a left half of LW elements and a right half of RW; each
    # half has H**width possible tuples.
    LW = W // 2
    RW = W - LW
    L, R = getNthPairBounded(N, H ** LW, H ** RW)
    return (getNthNTuple(L, LW, H=H, useLeftToRight=useLeftToRight) +
            getNthNTuple(R, RW, H=H, useLeftToRight=useLeftToRight))

def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Debug wrapper around GNT (getNthNTuple by default).

    Asserts that the result has exactly W elements, each below H.
    """
    t = GNT(N, W, H, useLeftToRight)
    assert len(t) == W
    assert all(elt < H for elt in t)
    return t

def getNthTuple(N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize and, for every y in x,
    0 <= y < maxElement. Index 0 is always the empty tuple.

    Raises NotImplementedError if maxElement is finite but maxSize is not.
    """
    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()
    N -= 1

    if maxElement is aleph0:
        # Unbounded elements: pair a size index (< maxSize) with an
        # unbounded index into the tuples of that size.
        S, M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
    else:
        if maxSize is aleph0:
            raise NotImplementedError('Max element size without max size unhandled')
        # There are maxElement**k tuples of each length k = 1..maxSize.
        counts = [maxElement ** k for k in range(1, maxSize + 1)]
        S, M = getNthPairVariableBounds(N, counts)

    return getNthNTuple(M, S + 1, maxElement, useLeftToRight=useLeftToRight)

def getNthTupleChecked(N, maxSize=aleph0, maxElement=aleph0,
                       useDivmod=False, useLeftToRight=False, GNT=getNthTuple):
    """Debug wrapper around GNT (getNthTuple by default).

    Asserts that the result has at most maxSize elements, each below
    maxElement.
    """
    # FIXME: maxSize is inclusive, unlike maxElement, which is exclusive.
    t = GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)
    assert len(t) <= maxSize
    assert all(elt < maxElement for elt in t)
    return t

def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    Raises ValueError if bounds is empty or N is outside [0, sum(bounds)).
    """
    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Visit columns from the smallest bound upwards. Each step enumerates
    # the band of rows [prevLevel, level) across every column whose bound
    # reaches at least that high.
    active = sorted(range(len(bounds)), key=bounds.__getitem__)
    prevLevel = 0
    for i, index in enumerate(active):
        level = bounds[index]
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W * H
        if N < levelSize:  # Found the level
            idelta, delta = getNthPairBounded(N, W, H)
            return active[i + idelta], prevLevel + delta
        N -= levelSize
        prevLevel = level

    # Unreachable once N has passed the range check above. (This was
    # previously misspelled, which would have raised NameError.)
    raise RuntimeError("Unexpected loop completion")

def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Debug wrapper around GNVP (getNthPairVariableBounds by default).

    Asserts that x indexes into bounds and that y lies below bounds[x].
    """
    x, y = GNVP(N, bounds)
    assert 0 <= x < len(bounds)
    assert 0 <= y < bounds[x]
    return (x, y)

###

def testPairs():
    """Print the pairs of a 3 x 6 rectangle in both enumeration orders.

    Each index is printed with its diagonal-order pair and its divmod-order
    pair. Two grids (top row first) then show where each index lands:
    'a' for the default diagonal order, 'b' for useDivmod=True.
    """
    W = 3
    H = 6
    diagonal = [['  '] * 10 for _ in range(10)]
    divmod_grid = [['  '] * 10 for _ in range(10)]
    for i in range(min(W * H, 40)):
        x, y = getNthPairBounded(i, W, H)
        x2, y2 = getNthPairBounded(i, W, H, useDivmod=True)
        print(i, (x, y), (x2, y2))
        diagonal[y][x] = '%2d' % i
        divmod_grid[y2][x2] = '%2d' % i

    for label, grid in (('-- a --', diagonal), ('-- b --', divmod_grid)):
        print(label)
        for row in reversed(grid):
            # Skip rows that received no index.
            if ''.join(row).strip():
                print('  '.join(row))

def testPairsVB():
    """Print pairs from getNthPairVariableBounds for mixed column bounds.

    Uses a mix of finite and aleph0 bounds and prints at most 40 indices
    with their pairs. A grid (top row first) then shows where each index
    lands.
    """
    bounds = [2, 2, 4, aleph0, 5, aleph0]
    # (A second, never-read grid that used to be allocated here was removed.)
    a = [['  ' for x in range(15)] for y in range(15)]
    for i in range(min(sum(bounds), 40)):
        x, y = getNthPairVariableBounds(i, bounds)
        print(i, (x, y))
        a[y][x] = '%2d' % i

    print('-- a --')
    for ln in a[::-1]:
        # Skip rows that received no index.
        if ''.join(ln).strip():
            print('  '.join(ln))

###

# Toggle to use checked versions of enumeration routines.
# Change False to True to route the public entry points through their
# assertion-checked wrappers. The wrappers bound the unchecked functions
# as default arguments when they were defined, so rebinding the
# module-level names here does not make a wrapper call itself.
if False:
    (getNthPairVariableBounds, getNthPairBounded,
     getNthNTuple, getNthTuple) = (getNthPairVariableBoundsChecked,
                                   getNthPairBoundedChecked,
                                   getNthNTupleChecked,
                                   getNthTupleChecked)

if __name__ == '__main__':
    # Run both demos; each prints its enumeration to stdout.
    for demo in (testPairs, testPairsVB):
        demo()