
/*  File   : terms.c
    Author : Richard A. O'Keefe.
    Updated: 2010
    Purpose: Terms and Unification in C
*/
#include <ctype.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

/*  A term is
    - a VARIABLE
      + which may be UNBOUND
      + or BOUND to some term
    - or a NON-VARIABLE, which consists of
      a FUNCTION SYMBOL and a sequence of
      zero or more ARGUMENTS, which are terms.
      The number of arguments is called the ARITY.
      The combination of a function symbol and an arity,
      commonly written f/n, is called a FUNCTOR.

    A function symbol is an uninterpreted atom.

    Programming languages based on first order logic commonly
    admit numbers as a kind of non-variable term, in which case
    we distinguish between numbers and callable terms.
    An ATOMIC term is either a number or a function symbol with
    no arguments; a COMPOUND term is a term with 1 or more
    arguments.

    This file covers
      representing terms
      classifying terms
      dereferencing terms (returning something that is not a bound variable)
      testing whether an unbound variable occurs in a term
      reading terms
      writing terms
      testing whether two terms are identical
      binding (and trailing) a variable
      unifying two terms.

    The syntax of the terms read by this file is

    term
        : variable
        | atom
        | atom '(' term_sequence ')'
        | list
    term_sequence
        : term [ ',' term_sequence ]
    list
        : '[' [term_sequence ['|' term]] ']'
    variable
        : /[_A-Z][A-Za-z0-9_]* /
    atom
        : /[a-z][A-Za-z0-9_]* /
        | /'([^']|'')*'/

    Lists are just a handy abbreviation.
    [] is the same as '[]'.
    [t1,t2,...,tn] is the same as [t1,t2,...,tn|[]]
    [t1|tt] is the same as '.'(t1,tt)
    [t1,t2,...,tn|tt] is the same as '.'(t1,[t2,...,tn|tt])
*/

/* ---------------------------------------------------------------------

    The string hash table holds a unique copy of each string.
*/

typedef struct Unique_String *ustr;
typedef struct Term_Info *term;

/*  One entry in the string hash table.  Entries in the same bucket
    are chained through `next`.  Each record is allocated with extra
    room after it so that `name` can hold `size` characters plus a
    terminating NUL; the declared length of 4 is only a placeholder.
*/
struct Unique_String {
    ustr            next;       /* next entry in the same bucket */
    term            variable;   /* only used while reading */
    term            atom;       /* also used when reading */
    size_t          quotes;     /* how many extra ' does it need */
    size_t          size;       /* number of characters in name */
    char            name[4];    /* really size characters and a NUL */
};

/*  Geometry of the string hash table: a directory of up to DIRSIZE
    segments, each of which holds SEGSIZE bucket heads.
*/
#define SEGBITS 9
#define SEGSIZE (1 << SEGBITS)
typedef ustr segment[SEGSIZE];
#define MAXLOAD 4               /* keys allowed per bucket, on average */
#define DIRBITS 9
#define DIRSIZE (1 << DIRBITS)
#define INIT_SEGS 8             /* segments allocated by init_strtab() */
/*  init_strtab() fills directory[0..INIT_SEGS-1], so the initial
    segments have to fit in the directory.
*/
#if INIT_SEGS > DIRSIZE
#error INIT_SEGS must not exceed DIRSIZE
#endif

/*  The string table grows one bucket at a time: bucket p is the next
    one to be split, and when p reaches maxp the table has doubled.
    Buckets live in fixed-size segments reached through `directory`.
*/
struct HASHTABLE {
    unsigned long p;            /* next bucket to split */
    unsigned long maxp;         /* upper bound on p during expansion */
    unsigned long max_load;     /* maximum load factor (in keys): */
                                /* n_segs*SEGSIZE*MAXLOAD */
    unsigned long n_keys;       /* number of keys */
    unsigned long n_segs;       /* number of segments */
    segment      *directory[DIRSIZE];
};

/*  The one and only string table; set up by init_strtab(). */
static struct HASHTABLE string_table;

/*  Allocate `size` bytes or die trying.
    No caller in this file checks the result, so running out of memory
    is reported on stderr and ends the program instead of returning 0.
*/
static void *emalloc(size_t size) {
    void *result = malloc(size);
    if (result == 0) {
        fprintf(stderr, "Out of memory\n");
        exit(EXIT_FAILURE);
    }
    return result;
}

/*  Set up the string table with INIT_SEGS empty segments.
    Must be called before the first call to find_mem().
*/
void init_strtab(void) {
    struct HASHTABLE * const h = &string_table;
    int i, j;

    h->p = 0;
    h->n_segs = INIT_SEGS;
    h->maxp = h->n_segs << SEGBITS;
    h->max_load = MAXLOAD * h->maxp;
    h->n_keys = 0;

    /* Allocate the initial segments, every bucket empty. */
    for (i = 0; i < (int)h->n_segs; i++) {
        segment * const s = emalloc(sizeof *s);

        for (j = 0; j < SEGSIZE; j++) (*s)[j] = 0;
        h->directory[i] = s;
    }
    /* The remaining directory slots have no segment yet. */
    for (; i < DIRSIZE; i++) h->directory[i] = 0;
}

/*  Hash the `len` bytes at `str`.
    Bytes are consumed four at a time; the 1 to 3 bytes left over (if
    any) are mixed in by the switch.  Only str[0..len-1] is ever read.
*/
static unsigned long hash_string(char const *str, size_t len) {
    unsigned char const *p = (unsigned char const *)str;
    unsigned long h;
    size_t L;

    h = 0;
    /* L is the index of the last byte of the group about to be used. */
    for (L = 3; L < len; L += 4, p += 4) {
        h = h * 31 + (p[0] + (p[1] << 3) + (p[2] << 6) + (p[3] << 9));
    }
    /* L - len is 3 minus the number of bytes still unused. */
    switch (L - len) {
      case 0:
        h += p[2] << 6;         /*FALLTHROUGH*/
      case 1:
        h += p[1] << 3;         /*FALLTHROUGH*/
      case 2:
        h += p[0];
        break;
      default:                  /* nothing left over */
        break;
    }
    return h;
}

/*  How many quote characters does printing this name need?
    Returns 0 for a name that starts with a lower case letter and
    continues with letters, digits, and underscores only.  Any other
    name needs an opening and a closing quote plus one extra quote
    for each quote character it contains.
*/
static size_t count_quotes(char const *str, size_t len) {
    size_t i, q;

    /* The <ctype.h> functions need unsigned char values; a plain char
       may be negative. */
    if (len != 0 && islower((unsigned char)str[0])) {
        i = 1;
        while (i != len
            && (isalnum((unsigned char)str[i]) || str[i] == '_')) i++;
        if (i == len) return 0;
    }
    q = 2;
    for (i = 0; i != len; i++) if (str[i] == '\'') q++;
    return q;
}

/*  Return the unique table entry for the `len` bytes at `str`,
    creating it if there is none yet.

    An entry that is found is moved to the front of its bucket.
    A new entry starts with no variable and no atom attached.
    When an insertion takes the key count past max_load, bucket p is
    split between itself and bucket p + maxp, unless the directory
    is already full.
*/
ustr find_mem(char const *str, size_t len) {
    struct HASHTABLE * const h = &string_table;
    /* Round the name up to a multiple of sizeof (unsigned long) and
       add one byte, so there is always room for the closing NUL. */
    size_t const pad = ((0u-len) & ((sizeof (unsigned long)) - 1)) + 1;
    unsigned long const u0 = hash_string(str, len);
    unsigned long const u1 = u0 & (h->maxp - 1);
    /* Buckets below p have already been split, so use one more bit. */
    unsigned long const u  = u1 >= h->p ? u1 : u0 & ((h->maxp << 1) - 1);
    ustr * const b = &(*h->directory[u>>SEGBITS])[u & (SEGSIZE - 1)];
    ustr *p;        /* previous */
    ustr  c;        /* current  */

    for (p = b; (c = *p) != 0; p = &c->next) {
        if (c->size == len && 0 == memcmp(c->name, str, len)) {
            /* Found it: unlink it and put it first in the bucket. */
            *p = c->next;
            c->next = *b;
            *b = c;
            return c;
        }
    }

    c = emalloc(offsetof(struct Unique_String, name) + len + pad);
    /* make_variable() and make_atom() test these fields against 0. */
    c->variable = 0;
    c->atom = 0;
    c->quotes = count_quotes(str, len);
    c->size = len;
    (void)memcpy(c->name, str, len);
    c->name[len] = '\0';
    c->next = *b;
    *b = c;

    if (++h->n_keys > h->max_load) {
        unsigned long const oldp = h->p;
        unsigned long const newp = oldp + h->maxp;

        if (newp < DIRSIZE * (unsigned long)SEGSIZE) {
            segment * const olds = h->directory[oldp >> SEGBITS];
            segment * news;

            if ((newp & (SEGSIZE-1)) == 0) {
                /* The new bucket is the first one in a new segment. */
                int j;

                news = emalloc(sizeof *news);
                for (j = 0; j < SEGSIZE; j++) (*news)[j] = 0;
                h->directory[newp >> SEGBITS] = news;
                h->n_segs++;
            } else {
                news = h->directory[newp >> SEGBITS];
            }
            h->max_load += MAXLOAD;
            {
                unsigned long const mask = (h->maxp << 1) - 1;
                ustr *prev = &(*olds)[oldp & (SEGSIZE-1)];
                ustr *tail = &(*news)[newp & (SEGSIZE-1)];
                ustr  curr;

                /* Move each entry that now belongs in the new bucket,
                   keeping the relative order of both chains. */
                while ((curr = *prev) != 0) {
                    unsigned long const slot =
                        mask & hash_string(curr->name, curr->size);

                    if (slot == newp) {
                        *tail = curr;
                        tail = &curr->next;
                        *prev = curr->next;
                    } else {
                        prev = &curr->next;
                    }
                }
                *tail = 0;
            }
            if (++h->p == h->maxp) {
                /* Every old bucket has been split: the table doubled. */
                h->maxp <<= 1;
                h->p = 0;
            }
        }
    }
    return c;
}

/*  We'll represent a term by a pointer to a record containing
       - the arity, using -1 for an unbound variable or -2 for a bound one
       - if the term is a variable, a pointer for its value
       - if the term is not a variable, a pointer to a unique string
         for its function symbol, and 0 or more pointers for its
         arguments.

    We're not going to bother using hash consing for these terms.
    Function symbols do not change at run time, but terms do.
*/
/*  A term record.  `arity` says which member of `u` is in use:
    u.binding for a variable (negative arity), u.name otherwise.
    Only non-variable terms with arity > 0 have any arg[] elements.
*/
struct Term_Info {
    int arity;                  /* >= 0, or one of the codes below */
    union {
        term binding;           /* value of a variable */
        ustr name;              /* function symbol of a non-variable */
    } u;
    term arg[];                 /* arity arguments */
};

#define UNBOUND_VARIABLE (-1)
#define BOUND_VARIABLE (-2)

/*  Make a fresh unbound variable that is not recorded under any name.
    An unbound variable's binding points at the variable itself.
*/
term make_anonymous_variable(void) {
    term r = emalloc(offsetof(struct Term_Info, arg[0]));
    r->arity = UNBOUND_VARIABLE;
    r->u.binding = r;
    return r;
}

/*  Return the variable recorded for `name`, making a fresh unbound one
    the first time the name is seen.  Later calls with the same name
    return the same variable.
*/
term make_variable(ustr name) {
    term r = name->variable;
    if (r == 0) {
        r = emalloc(offsetof(struct Term_Info, arg[0]));
        r->arity = UNBOUND_VARIABLE;
        r->u.binding = r;
        name->variable = r;
    }
    return r;
}

/*  Return the atom (a term with no arguments) called `name`.
    Just for efficiency, we use hash consing for atoms: the record is
    made once and remembered in the name's table entry.
*/
term make_atom(ustr name) {
    term r = name->atom;
    if (r == 0) {
        r = emalloc(offsetof(struct Term_Info, arg[0]));
        r->arity = 0;
        r->u.name = name;
        name->atom = r;
    }
    return r;
}

/*  Make a new non-variable term name(arg[0], ..., arg[arity-1]).
    The argument pointers are copied, so the caller's array may be
    reused afterwards.  A new record is made on every call.
*/
term make_term(ustr name, int arity, term const arg[]) {
    /* The size is spelled out because offsetof() with a subscript
       that is not a constant is not guaranteed to be accepted. */
    term r = emalloc(offsetof(struct Term_Info, arg)
                   + (size_t)arity * sizeof r->arg[0]);
    int i;

    r->arity = arity;
    r->u.name = name;
    for (i = 0; i < arity; i++) r->arg[i] = arg[i];
    return r;
}



/*  Skip layout (anything up to and including ' ') and % comments,
    which run to the end of the line.  Returns 1 with the next
    significant character pushed back on stdin, or 0 at end of input.
*/
static int more_input(void) {
    int c;

    do {
        c = getchar();
        if (c == '%') {
            do c = getchar(); while (c >= 0 && c != '\n');
        }
        if (c < 0) return 0;
    } while (c <= ' ');
    ungetc(c, stdin);
    return 1;
}

/*  Skip layout and % comments like more_input(), but consume and
    return the next significant character.  End of input here is a
    syntax error: it is reported and the program stops.  (Without the
    exit the loop would spin for ever, as EOF also counts as layout.)
*/
static int next_non_blank(void) {
    int c;

    do {
        c = getchar();
        if (c == '%') {
            do c = getchar(); while (c >= 0 && c != '\n');
        }
        if (c < 0) {
            fprintf(stderr, "Unexpected end of file\n");
            exit(EXIT_FAILURE);
        }
    } while (c <= ' ');
    return c;
}

/*  Report a name that does not fit in the reader's buffer and stop.
    This must not return: the caller stores a character right after
    calling it.
*/
static void name_buffer_overflow(void) {
    fprintf(stderr, "Variable name or atom too long.\n");
    exit(EXIT_FAILURE);
}

/*  Report a term or list with too many elements for the reader's
    buffer and stop.  This must not return: the caller stores an
    element right after calling it.
*/
static void term_buffer_overflow(void) {
    fprintf(stderr, "Term or list too long.\n");
    exit(EXIT_FAILURE);
}

static term nil_atom = 0;       /* the atom '[]', made on first use */
static ustr dot_name = 0;       /* the name '.',  made on first use */

/*  Read one term from stdin and return it.

    Handles variables, atoms (plain or quoted), compound terms, and
    list notation, which is turned into '.'/2 terms ending in '[]'
    or in the term after '|'.  A lone _ is a new variable each time;
    any other variable name always gives the same variable.
    Syntax errors are reported on stderr and end the program.
*/
term read_term(void) {
    int c;
    char name[1024];
    char *p = name;
    char * const e = name + sizeof name;
    term args[100];
    term *t = args;
    /* One past the last ELEMENT of args; sizeof args alone is a byte
       count and would put the limit far beyond the array. */
    term * const x = args + sizeof args / sizeof args[0];

    c = next_non_blank();

    if (isupper(c) || c == '_') {
        do {
            if (p == e) name_buffer_overflow();
            *p++ = (char)c;
            c = getchar();
        } while (isalnum(c) || c == '_');
        ungetc(c, stdin);
        if (p == name+1 && name[0] == '_') {
            return make_anonymous_variable();
        } else {
            return make_variable(find_mem(name, p-name));
        }
    }

    if (c == '[') {
        term r;

        if (nil_atom == 0) nil_atom = make_atom(find_mem("[]", 2));
        c = next_non_blank();
        if (c == ']') return nil_atom;
        ungetc(c, stdin);
        do {
            if (t == x) term_buffer_overflow();
            *t++ = read_term();
            c = next_non_blank();
        } while (c == ',');
        if (c == '|') {
            r = read_term();
            c = next_non_blank();
        } else {
            r = nil_atom;
        }
        if (c != ']') {
            fprintf(stderr, "Missing ]\n");
            exit(EXIT_FAILURE);
        }
        if (dot_name == 0) dot_name = find_mem(".", 1);
        /* Build the list from the back, one '.'/2 cell per element. */
        while (t != args) {
            term pair[2];

            pair[0] = *--t;
            pair[1] = r;
            r = make_term(dot_name, 2, pair);
        }
        return r;
    }

    if (islower(c)) {
        do {
            if (p == e) name_buffer_overflow();
            *p++ = (char)c;
            c = getchar();
        } while (isalnum(c) || c == '_');
    } else
    if (c == '\'') {
        /* Inside quotes, '' stands for one quote character. */
        for (;;) {
            c = getchar();
            if (c < 0) {
                fprintf(stderr, "Unexpected EOF\n");
                exit(EXIT_FAILURE);
            }
            if (c == '\'') {
                c = getchar();
                if (c != '\'') break;
            }
            if (p == e) name_buffer_overflow();
            *p++ = (char)c;
        }
    } else {
        /* Nothing that can start a term.  Checked here, before the
           atom code below gets the chance to make a nameless atom. */
        fprintf(stderr, "A term cannot begin with '%c'\n", c);
        exit(EXIT_FAILURE);
    }

    /* c is now the character after the atom. */
    if (c != '(') {
        ungetc(c, stdin);
        return make_atom(find_mem(name, p-name));
    }
    do {
        if (t == x) term_buffer_overflow();
        *t++ = read_term();
        c = next_non_blank();
    } while (c == ',');
    if (c != ')') {
        fprintf(stderr, "Missing )\n");
        exit(EXIT_FAILURE);
    }
    return make_term(find_mem(name, p-name), (int)(t-args), args);
}


/*  Classifying and decomposing terms  */


/*  Each macro takes a term pointer and inspects only the record it
    points to; none of them follows variable bindings first.
    Variables have a negative arity, atoms have arity 0, and
    compound terms have a positive arity.
*/
#define is_bound_variable(t)   ((t)->arity == BOUND_VARIABLE)
#define is_unbound_variable(t) ((t)->arity == UNBOUND_VARIABLE)
#define is_variable(t)         ((t)->arity < 0)
#define is_not_variable(t)     ((t)->arity >= 0)
#define is_atom(t)             ((t)->arity == 0)
#define is_compound(t)         ((t)->arity > 0)
/*  True for a '.'/2 term.  Evaluates t twice, and compares against
    dot_name, which read_term() only sets once it has read a list. */
#define is_dotted_pair(t)      ((t)->arity == 2 && (t)->u.name == dot_name)
/*  The binding field of a variable; usable as an lvalue. */
#define variable_binding(t)    ((t)->u.binding)

/*  Follow variable bindings until reaching something that is not a
    bound variable: an unbound variable or a non-variable term.
*/
term dereference(term t) {
    while (is_bound_variable(t)) t = variable_binding(t);
    return t;
}

/*  Return argument number i (counting from 0) of t after
    dereferencing t.  The result itself is not dereferenced.
    An index outside 0..arity-1 is reported on stderr and ends the
    program; this includes every index when t is a variable or atom.
*/
term argument(term t, int i) {
    t = dereference(t);
    if (i < 0 || i >= t->arity) {
        fprintf(stderr, "argument index out of range\n");
        exit(EXIT_FAILURE);
    }
    return t->arg[i];
}




/*  Write a term to stdout.
    Unbound variables are written as V followed by their address.
    Names that need quoting are written inside single quotes with
    each quote character doubled.  Compound terms are written as
    name(arg, ..., arg); lists get no special treatment.
*/
void print_term(term t) {
    t = dereference(t);
    if (is_unbound_variable(t)) {
        printf("V%p", (void *)t);       /* %p needs a void pointer */
    } else {
        char const *p = t->u.name->name;
        size_t n = t->u.name->size;
        size_t i;

        if (t->u.name->quotes == 0) {
            for (i = 0; i != n; i++) putchar(p[i]);
        } else {
            putchar('\'');
            for (i = 0; i != n; i++) {
                if (p[i] == '\'') putchar('\'');
                putchar(p[i]);
            }
            putchar('\'');
        }
        if (is_compound(t)) {
            int a;

            for (a = 0; a < t->arity; a++) {
                fputs(a == 0 ? "(" : ", ", stdout);
                print_term(t->arg[a]);
            }
            putchar(')');
        }
    }
}

#if 0   /* disabled driver for trying out the reader and writer */

/*  NOTE(review): only the first two lines of this function survived;
    the loop body below is a reconstruction and should be confirmed.
    The #endif is what matters: without it everything that follows
    is compiled out.
*/
int main(void) {
    while (more_input()) {
        print_term(read_term());
        putchar('\n');
    }
    return 0;
}

#endif


/*  Testing whether two terms are the same

    Two terms are the same if, after replacing bound variables
    by their values, either
        - both are the same unbound variable or
        - both have the same arity and function symbol,
          and corresponding arguments are the same.
*/

/*  Return 1 if x and y are the same term once bound variables are
    replaced by their values, and 0 otherwise.  Nothing is bound.
*/
int same(term x, term y) {
    int i;

    x = dereference(x);
    y = dereference(y);
    if (is_unbound_variable(x)) return x == y;
    if (is_unbound_variable(y)) return 0;
    if (x->arity != y->arity) return 0;
    /* Names are unique strings, so comparing pointers is enough. */
    if (x->u.name != y->u.name) return 0;
    for (i = 0; i < x->arity; i++) {
        if (!same(x->arg[i], y->arg[i])) return 0;
    }
    return 1;
}


/*  Checking whether a term contains a variable.
    We don't want X = f(X) to succeed making X = f(f(f(f(f(f....
    This check is called the "occurs check".
    For efficiency, most Prolog systems omit it.
    Alain Colmerauer showed that this is sort of legitimate
    as long as you admit to using a non-standard version of
    logic, one in which terms may be infinite.
*/

/*  Return 1 if the unbound variable v occurs anywhere in t (following
    bindings), and 0 otherwise.  v must already be dereferenced.
*/
int contains(term t, term v) {
    t = dereference(t);
    if (is_unbound_variable(t)) {
        return t == v;
    } else {
        int i;
        for (i = 0; i < t->arity; i++) {
            if (contains(t->arg[i], v)) return 1;
        }
        return 0;
    }
}


/*  In the unification process, we try to make two terms the same
    by binding variables to values.  This is a side effect!
    If the match fails, we want to undo it.
    Indeed, in the context of a full system, we might be backtracking
    over alternative proofs, and would need to undo successful matches.
    So we maintain a "trail".

    In general, a trail is a history of inverse actions.
    Before performing a side effect E, you push -E on the trail,
    where -E is the action that will undo E.
    In our case, the only action is "bind a variable", so if we
    know what the variable is, we know everything we need.

    With hindsight, I would have done better to thread the trail
    through the variable terms themselves.  I was too heavily
    influenced by Prolog, which uses some cleverness to avoid
    trailing bindings in the first place whenever they are sure
    to disappear before we need to undo them.
*/

term *trail_base = 0;           /* the trail: variables bound so far */
int   trail_top  = 0;           /* number of entries in use */
int   trail_limit = 0;          /* number of entries allocated */

/*  Bind the unbound variable v to t, recording v on the trail so
    that unwind_to() can undo the binding later.
    The trail starts at 1000 entries and doubles whenever it is full.
*/
static void bind(term v, term t) {
    if (trail_top == trail_limit) {
        int   new_limit = trail_top == 0 ? 1000 : trail_limit*2;
        term *new_trail = emalloc((size_t)new_limit * sizeof *new_trail);

        /* There is nothing to copy the first time, and trail_base is
           still null then, which memcpy() must not be given. */
        if (trail_top != 0) {
            memcpy(new_trail, trail_base,
                   (size_t)trail_top * sizeof *trail_base);
        }
        free(trail_base);       /* the old trail is no longer needed */
        trail_base  = new_trail;
        trail_limit = new_limit;
    }
    trail_base[trail_top++] = v;
    v->u.binding = t;
    v->arity = BOUND_VARIABLE;
}

/*  Undo every binding made since the trail had old_top entries,
    most recent first, making those variables unbound again.
*/
static void unwind_to(int old_top) {
    while (trail_top != old_top) {
        term v = trail_base[--trail_top];
        v->u.binding = v;
        v->arity = UNBOUND_VARIABLE;
    }
}


/*  Now for the big one:

        U N I F I C A T I O N !

    After all that preparation, it's simple.

    To unify two terms x and y:
        dereference them.
        if x is an unbound variable
            if y is x do nothing and succeed.
            if y is not x and y contains x, fail.
            otherwise bind x to y and succeed.
        if y is an unbound variable
            the same only the other way around.
        if x and y have different arities or function symbols, fail.
        otherwise, unify corresponding arguments,
        and fail if any of the recursive calls fails.
*/

/*  Try to make x and y the same by binding variables.
    Returns 1 on success and 0 on failure.  Bindings made before a
    failure was noticed are left in place; it is up to the caller to
    undo them with unwind_to().
*/
static int unify_worker(term x, term y) {
    int i;

    x = dereference(x);
    y = dereference(y);
    if (is_unbound_variable(x)) {
        if (y == x) return 1;
        if (contains(y, x)) return 0;   /* occurs check */
        bind(x, y);
    } else
    if (is_unbound_variable(y)) {
        if (x == y) return 1;
        if (contains(x, y)) return 0;   /* occurs check */
        bind(y, x);
    } else {
        if (x->arity != y->arity) return 0;
        if (x->u.name != y->u.name) return 0;
        for (i = 0; i < x->arity; i++) {
            if (!unify_worker(x->arg[i], y->arg[i])) return 0;
        }
    }
    return 1;
}

/*  Unify x and y.  Returns 1 and leaves the new bindings in place if
    they can be made the same; otherwise returns 0 with every binding
    made along the way undone again.
*/
int unify(term x, term y) {
    int old_trail = trail_top;

    if (unify_worker(x, y)) return 1;
    unwind_to(old_trail);       /* take back the partial match */
    return 0;
}




/*  Read equations of the form  term = term.  from stdin until end of
    input, try to unify the two sides, and report the outcome.
    Bindings are undone after each equation.

    NOTE(review): only part of this function survived.  What is
    printed after "YES: ", the failure message, and the stop after a
    syntax error are reconstructions and should be confirmed.
*/
int main(void) {
    term x, y;
    int old_trail;

    while (more_input()) {
        x = read_term();
        if (next_non_blank() != '=') {
            fprintf(stderr, "Missing =\n");
            exit(EXIT_FAILURE);
        }
        y = read_term();
        if (next_non_blank() != '.') {
            fprintf(stderr, "Missing .\n");
            exit(EXIT_FAILURE);
        }
        old_trail = trail_top;
        if (unify_worker(x, y)) {
            printf("YES: ");
            print_term(x);
            printf("\n");
        } else {
            printf("NO\n");
        }
        /* Variables are shared by name between equations, so leave
           them unbound for the next one. */
        unwind_to(old_trail);
    }
    return 0;
}