Разработка эвристики для проверки простых анонимных функций Python для эквивалентности

Я знаю, как работает сравнение функций в Python 3 (просто сравнивая адрес в памяти), и я понимаю, почему.

Я также понимаю, что "истинное" сравнение (если функции f и g возвращают тот же результат, учитывая те же аргументы, для любых аргументов?) практически невозможно.

Я ищу что-то среднее. Я хочу, чтобы сравнение работало над простейшими случаями одинаковых функций и, возможно, некоторыми менее тривиальными:

lambda x : x == lambda x : x # True
lambda x : 2 * x == lambda y : 2 * y # True
lambda x : 2 * x == lambda x : x * 2 # True or False is fine, but must be stable
lambda x : 2 * x == lambda x : x + x # True or False is fine, but must be stable

Обратите внимание, что меня интересует решение этой проблемы для анонимных функций (lambda), но не против, если решение также работает для именованных функций.

Мотивация для этого заключается в том, что внутри модуля blist было бы неплохо проверить, что два экземпляра sortedset имеют одну и ту же функцию сортировки перед выполнением объединения и т.д. на них.

Именованные функции менее интересны, потому что я могу предположить, что они будут разными, если они не идентичны. В конце концов, предположим, что кто-то создал два сортировки с именованной функцией в аргументе key. Если они предполагают, что эти экземпляры будут "совместимы" для целей заданных операций, они, вероятно, будут использовать одну и ту же функцию, а не две отдельные именованные функции, которые выполняют идентичные операции.

Я могу только думать о трех подходах. Все они кажутся трудными, поэтому любые идеи оцениваются.

  • Сравнение байткодов может работать, но может быть раздражающим, что он зависит от реализации (и, следовательно, код, который работал на одном Python, разрывается на другом).

  • Сравнение токенированного исходного кода кажется разумным и переносимым. Конечно, он менее мощный (поскольку идентичные функции, скорее всего, будут отвергнуты).

  • Твердая эвристика, заимствованная из некоторого учебника по символическому вычислению, теоретически является лучшим подходом. Это может показаться слишком тяжелым для моей цели, но на самом деле это может быть хорошим подспорьем, поскольку функции лямбда обычно крошечные, и поэтому они будут работать быстро.

ИЗМЕНИТЬ

Более сложный пример, основанный на комментарии @delnan:

# global variable
fields = ['id', 'name']

def my_function():
  global fields
  s1 = sortedset(key = lambda x : x[fields[0].lower()])
  # some intervening code here
  # ...
  s2 = sortedset(key = lambda x : x[fields[0].lower()])

Я ожидал, что ключевые функции для s1 и s2 будут оцениваться как равные?

Если промежуточный код вообще содержит любой вызов функции, значение fields может быть изменено, что приведет к различным ключевым функциям для s1 и s2. Поскольку мы явно не будем проводить анализ потока управления для решения этой проблемы, ясно, что мы должны оценивать эти две лямбда-функции как разные, если мы пытаемся выполнить эту оценку до выполнения. (Даже если fields не был глобальным, у него могло быть другое имя, связанное с ним и т.д.). Это серьезно сократило бы полезность всего этого упражнения, так как мало лямбда-функций не зависели бы от среды.

ИЗМЕНИТЬ 2:

Я понял, что очень важно сравнивать объекты функции, как они существуют во время выполнения. Без этого нельзя сравнивать все функции, зависящие от переменных из внешней области; и большинство полезных функций имеют такие зависимости. Рассматриваемые во время выполнения все функции с одной и той же сигнатурой сравнимы по чистому, логичному пути, независимо от того, на что они зависят, являются ли они нечистыми и т.д.

В результате мне нужен не только байт-код, но и глобальное состояние с момента создания функционального объекта (предположительно __globals__). Затем мне нужно сопоставить все переменные из внешней области с значениями из __globals__.

Ответ 1

Отредактировано, чтобы проверить, будет ли внешнее состояние влиять на функцию сортировки, а также, если две функции эквивалентны.


Я взломал dis.dis и друзей для вывода в глобальный файл-подобный объект. Затем я удалил номера строк и нормализованные имена переменных (не касаясь констант) и сравнив результат.

Вы можете очистить его так, чтобы dis.dis и друзья yield вышли строки, чтобы вам не пришлось ловить их вывод. Но это работающее доказательство концепции использования dis.dis для сравнения функций с минимальными изменениями.

import types
from opcode import *
_have_code = (types.MethodType, types.FunctionType, types.CodeType,
              types.ClassType, type)

def dis(x):
    """Disassemble classes, methods, functions, or code.

    With no argument, disassemble the last traceback.

    """
    if isinstance(x, types.InstanceType):
        x = x.__class__
    if hasattr(x, 'im_func'):
        x = x.im_func
    if hasattr(x, 'func_code'):
        x = x.func_code
    if hasattr(x, '__dict__'):
        items = x.__dict__.items()
        items.sort()
        for name, x1 in items:
            if isinstance(x1, _have_code):
                print >> out,  "Disassembly of %s:" % name
                try:
                    dis(x1)
                except TypeError, msg:
                    print >> out,  "Sorry:", msg
                print >> out
    elif hasattr(x, 'co_code'):
        disassemble(x)
    elif isinstance(x, str):
        disassemble_string(x)
    else:
        raise TypeError, \
              "don't know how to disassemble %s objects" % \
              type(x).__name__

def disassemble(co, lasti=-1):
    """Disassemble a code object."""
    code = co.co_code
    labels = findlabels(code)
    linestarts = dict(findlinestarts(co))
    n = len(code)
    i = 0
    extended_arg = 0
    free = None
    while i < n:
        c = code[i]
        op = ord(c)
        if i in linestarts:
            if i > 0:
                print >> out
            print >> out,  "%3d" % linestarts[i],
        else:
            print >> out,  '   ',

        if i == lasti: print >> out,  '-->',
        else: print >> out,  '   ',
        if i in labels: print >> out,  '>>',
        else: print >> out,  '  ',
        print >> out,  repr(i).rjust(4),
        print >> out,  opname[op].ljust(20),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
            extended_arg = 0
            i = i+2
            if op == EXTENDED_ARG:
                extended_arg = oparg*65536L
            print >> out,  repr(oparg).rjust(5),
            if op in hasconst:
                print >> out,  '(' + repr(co.co_consts[oparg]) + ')',
            elif op in hasname:
                print >> out,  '(' + co.co_names[oparg] + ')',
            elif op in hasjrel:
                print >> out,  '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                print >> out,  '(' + co.co_varnames[oparg] + ')',
            elif op in hascompare:
                print >> out,  '(' + cmp_op[oparg] + ')',
            elif op in hasfree:
                if free is None:
                    free = co.co_cellvars + co.co_freevars
                print >> out,  '(' + free[oparg] + ')',
        print >> out

def disassemble_string(code, lasti=-1, varnames=None, names=None,
                       constants=None):
    labels = findlabels(code)
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        if i == lasti: print >> out,  '-->',
        else: print >> out,  '   ',
        if i in labels: print >> out,  '>>',
        else: print >> out,  '  ',
        print >> out,  repr(i).rjust(4),
        print >> out,  opname[op].ljust(15),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            print >> out,  repr(oparg).rjust(5),
            if op in hasconst:
                if constants:
                    print >> out,  '(' + repr(constants[oparg]) + ')',
                else:
                    print >> out,  '(%d)'%oparg,
            elif op in hasname:
                if names is not None:
                    print >> out,  '(' + names[oparg] + ')',
                else:
                    print >> out,  '(%d)'%oparg,
            elif op in hasjrel:
                print >> out,  '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                if varnames:
                    print >> out,  '(' + varnames[oparg] + ')',
                else:
                    print >> out,  '(%d)' % oparg,
            elif op in hascompare:
                print >> out,  '(' + cmp_op[oparg] + ')',
        print >> out

def findlabels(code):
    """Detect all offsets in a byte code which are jump targets.

    Return the list of offsets.

    """
    labels = []
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            label = -1
            if op in hasjrel:
                label = i+oparg
            elif op in hasjabs:
                label = oparg
            if label >= 0:
                if label not in labels:
                    labels.append(label)
    return labels

def findlinestarts(code):
    """Find the offsets in a byte code which are start of lines in the source.

    Generate pairs (offset, lineno) as described in Python/compile.c.

    """
    byte_increments = [ord(c) for c in code.co_lnotab[0::2]]
    line_increments = [ord(c) for c in code.co_lnotab[1::2]]

    lastlineno = None
    lineno = code.co_firstlineno
    addr = 0
    for byte_incr, line_incr in zip(byte_increments, line_increments):
        if byte_incr:
            if lineno != lastlineno:
                yield (addr, lineno)
                lastlineno = lineno
            addr += byte_incr
        lineno += line_incr
    if lineno != lastlineno:
        yield (addr, lineno)

class FakeFile(object):
    def __init__(self):
        self.store = []
    def write(self, data):
        self.store.append(data)

a = lambda x : x
b = lambda x : x # True
c = lambda x : 2 * x
d = lambda y : 2 * y # True
e = lambda x : 2 * x
f = lambda x : x * 2 # True or False is fine, but must be stable
g = lambda x : 2 * x
h = lambda x : x + x # True or False is fine, but must be stable

funcs = a, b, c, d, e, f, g, h

outs = []
for func in funcs:
    out = FakeFile()
    dis(func)
    outs.append(out.store)

import ast

def outfilter(out):
    for i in out:
        if i.strip().isdigit():
            continue
        if '(' in i:
            try:
                ast.literal_eval(i)
            except ValueError:
                i = "(x)"
        yield i

processed_outs = [(out, 'LOAD_GLOBAL' in out or 'LOAD_DECREF' in out)
                            for out in (''.join(outfilter(out)) for out in outs)]

for (out1, polluted1), (out2, polluted2) in zip(processed_outs[::2], processed_outs[1::2]):
    print 'Bytecode Equivalent:', out1 == out2, '\nPolluted by state:', polluted1 or polluted2

Выходной сигнал True, True, False и False и является стабильным. "Загрязненный" bool является истинным, если выход будет зависеть от внешнего состояния - либо глобального состояния, либо закрытия.

Ответ 2

Итак, сначала рассмотрим некоторые технические проблемы.

1) Байт-код: это, вероятно, не проблема, потому что вместо проверки pyc (двоичных файлов) вы можете использовать модуль dis для получения "байт-кода". например.

>>> f = lambda x, y : x+y
>>> dis.dis(f)
  1           0 LOAD_FAST                0 (x)
              3 LOAD_FAST                1 (y)
              6 BINARY_ADD          
              7 RETURN_VALUE 

Не нужно беспокоиться о платформе.

2) Обозначенный исходный код. Снова у python есть все, что вам нужно для выполнения этой работы. Вы можете использовать модуль ast для анализа кода и получения ast.

>>> a = ast.parse("f = lambda x, y : x+y")
>>> ast.dump(a)
"Module(body=[Assign(targets=[Name(id='f', ctx=Store())], value=Lambda(args=arguments(args=[Name(id='x', ctx=Param()), Name(id='y', ctx=Param())], vararg=None, kwarg=None, defaults=[]), body=BinOp(left=Name(id='x', ctx=Load()), op=Add(), right=Name(id='y', ctx=Load()))))])"

Итак, вопрос, который мы должны решить, заключается в следующем: возможно ли определить, что две функции эквивалентны аналитически?

Человеку легко сказать, что 2*x равно x+x, но как мы можем создать алгоритм для его доказательства?

Если это то, чего вы хотите достичь, вы можете проверить это: http://en.wikipedia.org/wiki/Computer-assisted_proof

Однако, если в конечном итоге вы просто хотите утверждать, что два разных набора данных отсортированы в одном порядке, вам просто нужно запустить функцию сортировки A в наборе данных B и наоборот, а затем проверить результат. Если они идентичны, то функции, вероятно, функционально идентичны. Конечно, проверка действительна только для указанных наборов данных.