#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

/*
 * emalloc: allocate n bytes or terminate the program.
 *
 * Allocation failure is reported on stderr and the process aborts.  The
 * check is explicit rather than an assert() so that it is not compiled
 * out when NDEBUG is defined.
 */
void *emalloc(size_t n)
{
    void *r = malloc(n);

    if (r == NULL) {
        fprintf(stderr, "emalloc: out of memory (%lu bytes)\n",
                (unsigned long) n);
        abort();
    }
    return r;
}

/* An expression is a pointer to a heap-allocated, tagged node. */
typedef struct Expr_Node *expr;

/*
 * Node tags.  The binary-operator tags are the operator characters
 * themselves; print() relies on this to render binary nodes.
 */
#define LIT 'c'
struct Literal { int val; };

#define NEG 'n'
struct Negated { expr arg; };

#define ADD '+'
#define SUB '-'
#define MUL '*'
#define DIV '/'
#define MOD '%'
struct Binary { expr lhs, rhs; };

struct Expr_Node {
    char op;                    /* one of the tags above */
    union {
        struct Literal l;       /* op == LIT */
        struct Negated n;       /* op == NEG */
        struct Binary b;        /* ADD, SUB, MUL, DIV, MOD */
    } u;
};

/* Field accessors; each expands to an lvalue, so they can be assigned. */
#define lit_val(p) ((p)->u.l.val)
#define neg_arg(p) ((p)->u.n.arg)
#define bin_lhs(p) ((p)->u.b.lhs)
#define bin_rhs(p) ((p)->u.b.rhs)

/*
 * print: write e to stdout in fully parenthesized infix form, e.g.
 * "((1-2)*(3+4))".  Negation is written as a bare '-' prefix, so a
 * negated negative literal prints as "--1".
 */
void print(expr e)
{
    switch (e->op) {
    case LIT:
        printf("%d", lit_val(e));
        break;
    case NEG:
        printf("-");
        print(neg_arg(e));
        break;
    default:
        printf("(");
        print(bin_lhs(e));
        printf("%c", e->op);
        print(bin_rhs(e));
        printf(")");
        break;
    }
}

/* Report an evaluation error on stderr and terminate the program. */
static void eval_error(const char *what)
{
    fprintf(stderr, "value: %s\n", what);
    abort();
}

/*
 * Checked int arithmetic.  Each helper tests its operands *before*
 * performing the operation, because signed overflow, division by zero and
 * INT_MIN / -1 are undefined behavior in C (CERT INT32-C, INT33-C).
 */
static int checked_neg(int a)
{
    if (a == INT_MIN)
        eval_error("integer overflow in negation");
    return -a;
}

static int checked_add(int a, int b)
{
    if ((b > 0 && a > INT_MAX - b) || (b < 0 && a < INT_MIN - b))
        eval_error("integer overflow in addition");
    return a + b;
}

static int checked_sub(int a, int b)
{
    if ((b < 0 && a > INT_MAX + b) || (b > 0 && a < INT_MIN + b))
        eval_error("integer overflow in subtraction");
    return a - b;
}

static int checked_mul(int a, int b)
{
    int overflow;

    if (a > 0) {
        if (b > 0)
            overflow = a > INT_MAX / b;
        else
            overflow = b < INT_MIN / a;
    } else {
        if (b > 0)
            overflow = a < INT_MIN / b;
        else
            overflow = a != 0 && b < INT_MAX / a;
    }
    if (overflow)
        eval_error("integer overflow in multiplication");
    return a * b;
}

static int checked_div(int a, int b)
{
    if (b == 0)
        eval_error("division by zero");
    if (a == INT_MIN && b == -1)
        eval_error("integer overflow in division");
    return a / b;
}

static int checked_mod(int a, int b)
{
    if (b == 0)
        eval_error("remainder by zero");
    /* Mathematically 0, but INT_MIN % -1 is undefined behavior in C. */
    if (b == -1)
        return 0;
    return a % b;
}

/*
 * value: evaluate e with C int semantics (division truncates toward zero).
 * Division or remainder by zero and any result outside the range of int
 * are reported on stderr and abort the program.
 */
int value(expr e)
{
    switch (e->op) {
    case LIT:
        return lit_val(e);
    case NEG:
        return checked_neg(value(neg_arg(e)));
    case ADD:
        return checked_add(value(bin_lhs(e)), value(bin_rhs(e)));
    case SUB:
        return checked_sub(value(bin_lhs(e)), value(bin_rhs(e)));
    case MUL:
        return checked_mul(value(bin_lhs(e)), value(bin_rhs(e)));
    case DIV:
        return checked_div(value(bin_lhs(e)), value(bin_rhs(e)));
    case MOD:
        return checked_mod(value(bin_lhs(e)), value(bin_rhs(e)));
    default:
        abort(); /* unknown tag: the tree is corrupt */
    }
}

/* Constructors.  Each returns a new node allocated with emalloc(). */

expr mk_literal(int v)
{
    expr r = emalloc(sizeof *r);

    r->op = LIT;
    lit_val(r) = v;
    return r;
}

expr mk_negated(expr e)
{
    expr r = emalloc(sizeof *r);

    r->op = NEG;
    neg_arg(r) = e;
    return r;
}

/* Shared body of the binary-node constructors below. */
static expr mk_binary(char op, expr e, expr f)
{
    expr r = emalloc(sizeof *r);

    r->op = op;
    bin_lhs(r) = e;
    bin_rhs(r) = f;
    return r;
}

expr mk_sum(expr e, expr f)        { return mk_binary(ADD, e, f); }
expr mk_difference(expr e, expr f) { return mk_binary(SUB, e, f); }
expr mk_product(expr e, expr f)    { return mk_binary(MUL, e, f); }
expr mk_quotient(expr e, expr f)   { return mk_binary(DIV, e, f); }
expr mk_remainder(expr e, expr f)  { return mk_binary(MOD, e, f); }

/*
 * free_expr: release e and all of its sub-expressions.  Accepts NULL.
 * Assumes e is a tree; a node shared by two parents would be freed twice.
 */
void free_expr(expr e)
{
    if (e == NULL)
        return;
    switch (e->op) {
    case LIT:
        break;
    case NEG:
        free_expr(neg_arg(e));
        break;
    default:
        free_expr(bin_lhs(e));
        free_expr(bin_rhs(e));
        break;
    }
    free(e);
}

/* Build, print and evaluate ((1-2)*(3+4)), then free it. */
int main(void)
{
    expr e = mk_product(mk_difference(mk_literal(1), mk_literal(2)),
                        mk_sum(mk_literal(3), mk_literal(4)));

    print(e);
    printf(" == %d\n", value(e));
    free_expr(e);
    return 0;
}