mirror of https://github.com/rui314/9cc.git
260 lines
6.3 KiB
C
260 lines
6.3 KiB
C
#include "9cc.h"
|
|
|
|
// Semantic analyzer. This pass plays a few important roles as shown
// below:
//
// - Add types to nodes. For example, a tree that represents "1+2" is
//   typed as INT because the result type of an addition of two
//   integers is integer.
//
// - Insert nodes to make array-to-pointer conversion explicit.
//   Recall that, in C, "array of T" is automatically converted to
//   "pointer to T" in most contexts.
//
// - Insert nodes for implicit cast so that they are explicitly
//   represented in AST.
//
// - Scale operands for pointer arithmetic. E.g. ptr+1 becomes ptr+4
//   for integer and becomes ptr+8 for pointer.
//
// - Reject bad assignments, such as `1=2+3`.
|
|
|
|
// Makes array-to-pointer conversion explicit.
//
// When `decay` is set and `base` has array type, returns a freshly
// allocated ND_ADDR node that wraps `base` and is typed as a pointer to
// the array's element type. In every other case `base` itself is
// returned untouched.
static Node *maybe_decay(Node *base, bool decay) {
  // `decay` is tested first so base->ty is only inspected when needed.
  if (decay && base->ty->ty == ARY) {
    Node *addr = calloc(1, sizeof *addr);
    addr->op = ND_ADDR;
    addr->expr = base;
    addr->token = base->token;
    addr->ty = ptr_to(base->ty->ary_of);
    return addr;
  }
  return base;
}
|
|
|
|
// Reports `msg` as an error located at the token `node` originated
// from. Forwards to bad_token and, as declared, never returns.
noreturn static void bad_node(Node *node, char *msg) {
  bad_token(node->token, msg);
}
|
|
|
|
// Verifies that `node` is an lvalue expression: a variable reference,
// a dereference, or a struct member access. Any other node kind is
// reported via bad_node with "not an lvalue".
static void check_lval(Node *node) {
  switch (node->op) {
  case ND_VARREF:
  case ND_DEREF:
  case ND_DOT:
    return;
  default:
    bad_node(node, "not an lvalue");
  }
}
|
|
|
|
// Builds the binary node `base <op> N`, where N is an integer literal
// node holding the size of the type that `ty` points to. Used to
// convert between element counts and byte offsets in pointer
// arithmetic.
//
// The new node inherits `base`'s token. Its `ty` field is left as
// zero-initialized by calloc; callers assign it when they need one.
static Node *scale_ptr(int op, Node *base, Type *ty) {
  Node *scaled = calloc(1, sizeof *scaled);
  scaled->token = base->token;
  scaled->op = op;
  scaled->lhs = base;
  scaled->rhs = new_int_node(ty->ptr_to->size, base->token);
  return scaled;
}
|
|
|
|
// Wraps `base` in a newly allocated ND_CAST node of type `ty`, so that
// an implicit conversion becomes an explicit node in the AST. The cast
// node reuses `base`'s token.
static Node *cast(Node *base, Type *ty) {
  Node *conv = calloc(1, sizeof *conv);
  conv->token = base->token;
  conv->expr = base;
  conv->ty = ty;
  conv->op = ND_CAST;
  return conv;
}
|
|
|
|
// Verifies that `node` has an integer-like type (INT, CHAR or BOOL).
// Any other type is reported via bad_node with "not an integer".
static void check_int(Node *node) {
  int kind = node->ty->ty;
  if (kind == INT || kind == CHAR || kind == BOOL)
    return;
  bad_node(node, "not an integer");
}
|
|
|
|
static Node *do_walk(Node *node, bool decay);
|
|
|
|
// Analyzes `node` with array-to-pointer decay enabled for its result.
// Returns the node that should replace `node` in the tree.
static Node *walk(Node *node) {
  const bool decay = true;
  return do_walk(node, decay);
}
|
|
|
|
// Analyzes `node` with array-to-pointer decay disabled for its result,
// so an array-typed expression keeps its array type. Returns the node
// that should replace `node` in the tree.
static Node *walk_nodecay(Node *node) {
  const bool decay = false;
  return do_walk(node, decay);
}
|
|
|
|
// Recursively analyzes the tree rooted at `node`: assigns a type to
// expression nodes, inserts explicit nodes for array decay and implicit
// bool casts, scales pointer arithmetic, and rejects invalid operands.
//
// `decay` controls whether an array-typed result of this node is
// wrapped in ND_ADDR (see maybe_decay). It only affects ND_VARREF,
// ND_DOT and ND_DEREF; children are walked with decay enabled unless
// noted otherwise below.
//
// Returns the node that must replace `node` in its parent (it may be a
// new wrapper node). Errors are reported through bad_node, which does
// not return.
static Node *do_walk(Node *node, bool decay) {
  switch (node->op) {
  case ND_NUM:
  case ND_NULL:
  case ND_BREAK:
  case ND_CONTINUE:
    // Leaves: nothing to analyze.
    return node;
  case ND_VARREF:
    return maybe_decay(node, decay);
  case ND_IF:
    node->cond = walk(node->cond);
    node->then = walk(node->then);
    if (node->els)
      node->els = walk(node->els);
    return node;
  case ND_FOR:
    // Every clause of a for statement except the body is optional.
    if (node->init)
      node->init = walk(node->init);
    if (node->cond)
      node->cond = walk(node->cond);
    if (node->inc)
      node->inc = walk(node->inc);
    node->body = walk(node->body);
    return node;
  case ND_DO_WHILE:
  case ND_SWITCH:
    node->cond = walk(node->cond);
    node->body = walk(node->body);
    return node;
  case ND_CASE:
    node->body = walk(node->body);
    return node;
  case '+':
    node->lhs = walk(node->lhs);
    node->rhs = walk(node->rhs);

    // Normalize "int + ptr" to "ptr + int" so that only the left-hand
    // side can be a pointer from here on.
    if (node->rhs->ty->ty == PTR) {
      Node *n = node->lhs;
      node->lhs = node->rhs;
      node->rhs = n;
    }
    // Also rejects "ptr + ptr": after the swap the rhs is still a pointer.
    check_int(node->rhs);

    if (node->lhs->ty->ty == PTR) {
      // ptr + n advances by n elements, i.e. n * sizeof(element) bytes.
      node->rhs = scale_ptr('*', node->rhs, node->lhs->ty);
      node->ty = node->lhs->ty;
    } else {
      node->ty = int_ty();
    }
    return node;
  case '-': {
    node->lhs = walk(node->lhs);
    node->rhs = walk(node->rhs);

    Type *lty = node->lhs->ty;
    Type *rty = node->rhs->ty;

    if (lty->ty == PTR && rty->ty == PTR) {
      if (!same_type(rty, lty))
        bad_node(node, "incompatible pointer");
      // ptr - ptr yields an element count: divide the byte distance by
      // the element size. The result is an integer, not a pointer, so
      // it can be compared or used in further integer arithmetic.
      node = scale_ptr('/', node, lty);
      node->ty = int_ty();
    } else if (lty->ty == PTR) {
      // ptr - n steps back by n elements, mirroring the ptr + n case.
      check_int(node->rhs);
      node->rhs = scale_ptr('*', node->rhs, lty);
      node->ty = lty;
    } else {
      node->ty = int_ty();
    }
    return node;
  }
  case '=':
    // The assignment target must keep its own type, so no decay here.
    node->lhs = walk_nodecay(node->lhs);
    check_lval(node->lhs);
    node->rhs = walk(node->rhs);
    // Make the implicit conversion to bool explicit in the tree.
    if (node->lhs->ty->ty == BOOL)
      node->rhs = cast(node->rhs, bool_ty());
    node->ty = node->lhs->ty;
    return node;
  case ND_DOT: {
    node->expr = walk(node->expr);
    if (node->expr->ty->ty != STRUCT)
      bad_node(node, "struct expected before '.'");

    Type *ty = node->expr->ty;
    // A struct type without a member table has not been completed.
    if (!ty->members)
      bad_node(node, "incomplete type");

    // The member table maps a member name to that member's type.
    node->ty = map_get(ty->members, node->name);
    if (!node->ty)
      bad_node(node, format("member missing: %s", node->name));
    return maybe_decay(node, decay);
  }
  case '?':
    node->cond = walk(node->cond);
    node->then = walk(node->then);
    node->els = walk(node->els);
    // The result type is taken from the `then` branch only.
    node->ty = node->then->ty;
    return node;
  case '*':
  case '/':
  case '%':
  case '<':
  case '|':
  case '^':
  case '&':
  case ND_EQ:
  case ND_NE:
  case ND_LE:
  case ND_SHL:
  case ND_SHR:
  case ND_LOGAND:
  case ND_LOGOR:
    // Binary operators that accept only integer operands and yield int.
    node->lhs = walk(node->lhs);
    node->rhs = walk(node->rhs);
    check_int(node->lhs);
    check_int(node->rhs);
    node->ty = int_ty();
    return node;
  case ',':
    // A comma expression has the type of its right-hand operand.
    node->lhs = walk(node->lhs);
    node->rhs = walk(node->rhs);
    node->ty = node->rhs->ty;
    return node;
  case '!':
  case '~':
    node->expr = walk(node->expr);
    check_int(node->expr);
    node->ty = int_ty();
    return node;
  case ND_ADDR:
    // The operand of unary & is exempt from array-to-pointer decay.
    // Decaying it would turn `&array` into the address of an ND_ADDR
    // node, which is not an lvalue.
    node->expr = walk_nodecay(node->expr);
    check_lval(node->expr);
    node->ty = ptr_to(node->expr->ty);
    // Record on the variable that its address has been taken.
    if (node->expr->op == ND_VARREF)
      node->expr->var->address_taken = true;
    return node;
  case ND_DEREF:
    node->expr = walk(node->expr);

    if (node->expr->ty->ty != PTR)
      bad_node(node, "operand must be a pointer");

    if (node->expr->ty->ptr_to->ty == VOID)
      bad_node(node, "cannot dereference void pointer");

    node->ty = node->expr->ty->ptr_to;
    return maybe_decay(node, decay);
  case ND_RETURN:
  case ND_EXPR_STMT:
    node->expr = walk(node->expr);
    return node;
  case ND_CALL:
    for (int i = 0; i < node->args->len; i++)
      node->args->data[i] = walk(node->args->data[i]);
    // On entry node->ty is read as a type with a `returning` field
    // (presumably the callee's function type, set by the parser); the
    // call expression itself gets that return type.
    node->ty = node->ty->returning;
    return node;
  case ND_COMP_STMT: {
    for (int i = 0; i < node->stmts->len; i++)
      node->stmts->data[i] = walk(node->stmts->data[i]);
    return node;
  }
  case ND_STMT_EXPR: {
    for (int i = 0; i < node->stmts->len; i++)
      node->stmts->data[i] = walk(node->stmts->data[i]);
    // The value and type of a statement expression come from node->expr.
    node->expr = walk(node->expr);
    node->ty = node->expr->ty;
    return node;
  }
  default:
    assert(0 && "unknown node type");
  }
}
|
|
|
|
// Returns the type of the expression `node`. The expression is run
// through the semantic walk with decay disabled at the top level, so an
// array expression reports its array type rather than a pointer type.
Type *get_type(Node *node) {
  Node *typed = walk_nodecay(node);
  return typed->ty;
}
|
|
|
|
// Entry point of the semantic analysis pass. Walks the body of every
// function in `prog`, replacing each body with its analyzed tree.
// Each function's top-level node is asserted to be ND_FUNC.
void sema(Program *prog) {
  int i = 0;
  while (i < prog->funcs->len) {
    Function *fn = prog->funcs->data[i];
    Node *root = fn->node;
    assert(root->op == ND_FUNC);
    root->body = walk(root->body);
    i++;
  }
}
|