934 lines
33 KiB
C++
934 lines
33 KiB
C++
#include "codegen.h"
|
|
#include <iostream>
|
|
|
|
// Creates the LLVM context, the module ("sun_module") and the IR builder, then
// adds external declarations for the runtime functions that generated code
// calls by name.
CodeGen::CodeGen() {
    context = std::make_unique<llvm::LLVMContext>();
    module = std::make_unique<llvm::Module>("sun_module", *context);
    builder = std::make_unique<llvm::IRBuilder<>>(*context);

    // Every runtime signature below is assembled from these three types.
    llvm::Type* const voidTy = llvm::Type::getVoidTy(*context);
    llvm::Type* const i32Ty = llvm::Type::getInt32Ty(*context);
    llvm::Type* const i8PtrTy = llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0);

    // Adds one non-variadic, externally linked declaration to the module.
    const auto declare = [this](const char* name, llvm::Type* result, llvm::ArrayRef<llvm::Type*> params) {
        llvm::FunctionType* signature = llvm::FunctionType::get(result, params, false);
        llvm::Function::Create(signature, llvm::Function::ExternalLinkage, name, module.get());
    };

    // void print_int(int)
    declare("print_int", voidTy, {i32Ty});
    // void print_string(char*)
    declare("print_string", voidTy, {i8PtrTy});
    // char* int_to_str(int)
    declare("int_to_str", i8PtrTy, {i32Ty});
    // char* str_concat(char*, char*)
    declare("str_concat", i8PtrTy, {i8PtrTy, i8PtrTy});
    // void* sun_array_create(int size)
    declare("sun_array_create", i8PtrTy, {i32Ty});
    // void sun_array_set(void* arr, int index, void* value)
    declare("sun_array_set", voidTy, {i8PtrTy, i32Ty, i8PtrTy});
    // void* sun_array_get(void* arr, int index)
    declare("sun_array_get", i8PtrTy, {i8PtrTy, i32Ty});
    // int sun_array_length(void* arr)
    declare("sun_array_length", i32Ty, {i8PtrTy});
    // void* sun_read_csv(char* filename, char* separator, char* quote)
    declare("sun_read_csv", i8PtrTy, {i8PtrTy, i8PtrTy, i8PtrTy});
}
|
|
|
|
// Emits an alloca of `type`, named `varName`, at the very beginning of the
// entry block of `theFunction` and returns it. A temporary builder is used, so
// the insertion point of the member `builder` is not changed.
llvm::AllocaInst* CodeGen::createEntryBlockAlloca(llvm::Function* theFunction, const std::string& varName, llvm::Type* type) {
    llvm::BasicBlock& entry = theFunction->getEntryBlock();
    llvm::IRBuilder<> entryBuilder(&entry, entry.begin());
    llvm::AllocaInst* slot = entryBuilder.CreateAlloca(type, nullptr, varName);
    return slot;
}
|
|
|
|
// Generates code for one top-level AST node.
//
// Function and class definitions are visited directly. Any other node is
// emitted into an implicit `main` function (returns i32, takes no parameters)
// that is created the first time such a node is seen. `mainInsertBlock`
// records the block in `main` where the next top-level statement continues.
void CodeGen::generate(ASTNode& node) {
    const bool isDefinition = dynamic_cast<FunctionDef*>(&node) != nullptr ||
                              dynamic_cast<ClassDef*>(&node) != nullptr;

    if (isDefinition) {
        // Remember main's position before the definition is generated, but only
        // while the builder really is inside main. Visiting a FunctionDef leaves
        // the builder inside that function, so for a second consecutive
        // definition the current block belongs to the previous function; saving
        // it would redirect all later top-level code into that function.
        llvm::BasicBlock* current = builder->GetInsertBlock();
        if (mainFunction && current && current->getParent() == mainFunction) {
            mainInsertBlock = current;
        }
        node.accept(*this);
        return;
    }

    if (!mainFunction) {
        // First top-level statement: create `i32 main()` and start emitting
        // into its entry block.
        llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getInt32Ty(*context), false);
        mainFunction = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "main", module.get());
        llvm::BasicBlock* bb = llvm::BasicBlock::Create(*context, "entry", mainFunction);
        builder->SetInsertPoint(bb);
    } else if (mainInsertBlock) {
        // Resume in main where the previous top-level statement stopped.
        builder->SetInsertPoint(mainInsertBlock);
    }

    node.accept(*this);

    // The statement may have created new blocks (if/loops); continue after them.
    mainInsertBlock = builder->GetInsertBlock();
}
|
|
|
|
void CodeGen::finish() {
|
|
if (mainFunction) {
|
|
if (mainInsertBlock && !mainInsertBlock->getTerminator()) {
|
|
builder->SetInsertPoint(mainInsertBlock);
|
|
builder->CreateRet(llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)));
|
|
}
|
|
}
|
|
}
|
|
|
|
void CodeGen::print() {
|
|
module->print(llvm::outs(), nullptr);
|
|
}
|
|
|
|
// Produces a signed 32-bit integer constant for a numeric literal. The result
// carries no class tag.
void CodeGen::visit(NumberExpr& node) {
    const llvm::APInt literalBits(32, node.value, true);
    lastValue = llvm::ConstantInt::get(*context, literalBits);
    lastClassName = "";  // plain integer, not a class instance
}
|
|
|
|
// Emits the literal's text as a global string and tags the result with the
// special class name "String", which other visitors use to recognise strings.
void CodeGen::visit(StringExpr& node) {
    llvm::Value* literal = builder->CreateGlobalString(node.value);
    lastValue = literal;
    lastClassName = "String";
}
|
|
|
|
// Loads the current value of a named variable from its stack slot.
//
// On success `lastValue` is the loaded value and `lastClassName` is the tag
// recorded for the variable in `varTypes` (empty when none is recorded). For
// an unknown name an error is printed and `lastValue` becomes nullptr.
void CodeGen::visit(VariableExpr& node) {
    const auto slotIt = namedValues.find(node.name);
    if (slotIt == namedValues.end()) {
        std::cerr << "Unknown variable name: " << node.name << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    // The slot's allocated type tells the load what type to read.
    llvm::AllocaInst* slot = slotIt->second;
    lastValue = builder->CreateLoad(slot->getAllocatedType(), slot, node.name.c_str());

    const auto tagIt = varTypes.find(node.name);
    if (tagIt == varTypes.end()) {
        lastClassName = "";
    } else {
        lastClassName = tagIt->second;
    }
}
|
|
|
|
// Generates `name = value` for an already declared variable.
//
// The value expression is evaluated first. If it failed (`lastValue` is
// nullptr) nothing is emitted. Otherwise the value is stored into the
// variable's stack slot and becomes the result of the expression. The
// variable's class tag in `varTypes` is kept in sync with the stored value.
//
// Errors (unknown variable, value type differing from the slot type) are
// reported on stderr and leave `lastValue` as nullptr.
void CodeGen::visit(AssignExpr& node) {
    node.value->accept(*this);
    llvm::Value* val = lastValue;
    const std::string valClassName = lastClassName;  // tag of the assigned value

    if (!val) return;

    const auto slotIt = namedValues.find(node.name);
    if (slotIt == namedValues.end()) {
        std::cerr << "Unknown variable name: " << node.name << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }
    llvm::AllocaInst* slot = slotIt->second;

    // The slot was sized for the type of the variable's initial value. Storing
    // a value of another type (e.g. a pointer into an i32 slot) would write
    // past the slot or leave later loads reading the wrong type.
    if (val->getType() != slot->getAllocatedType()) {
        std::cerr << "Type mismatch in assignment to variable: " << node.name << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    builder->CreateStore(val, slot);
    lastValue = val;

    // Update type tracking; drop a stale tag when the new value has none.
    if (!valClassName.empty()) {
        varTypes[node.name] = valClassName;
    } else {
        varTypes.erase(node.name);
    }
    lastClassName = valClassName;
}
|
|
|
|
// Generates code for a call expression.
//
// Three callee names are treated as built-ins before the module is consulted:
//   readcsv(file[, separator[, quote]])  calls sun_read_csv; result tagged "CSVArray".
//                                        Missing separator/quote default to "," and "\"".
//   len(array)                           calls sun_array_length; result is an integer.
//   print(value)                         calls print_string for values tagged "String",
//                                        print_int for 32-bit integers; yields no value.
// Any other name must be a function already present in the module, called
// with exactly as many arguments as it has parameters.
//
// On every error path `lastValue` is nullptr and `lastClassName` is empty. When
// an argument sub-expression failed, no additional message is printed here.
void CodeGen::visit(CallExpr& node) {
    llvm::Type* const i8PtrTy = llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0);

    // Common error exit; `message` may be nullptr when the failing
    // sub-expression is expected to have reported the problem already.
    const auto fail = [this](const char* message) {
        if (message) {
            std::cerr << message << std::endl;
        }
        lastValue = nullptr;
        lastClassName = "";
    };

    // ---- readcsv ---------------------------------------------------------
    if (node.callee == "readcsv") {
        if (node.args.size() < 1 || node.args.size() > 3) {
            fail("readcsv() takes 1 to 3 arguments.");
            return;
        }

        node.args[0]->accept(*this);
        llvm::Value* filenameVal = lastValue;
        if (!filenameVal) {
            fail(nullptr);
            return;
        }

        llvm::Value* sepVal = nullptr;
        if (node.args.size() >= 2) {
            node.args[1]->accept(*this);
            sepVal = lastValue;
            if (!sepVal) {
                fail(nullptr);
                return;
            }
        } else {
            sepVal = builder->CreateGlobalStringPtr(",");
        }

        llvm::Value* quoteVal = nullptr;
        if (node.args.size() >= 3) {
            node.args[2]->accept(*this);
            quoteVal = lastValue;
            if (!quoteVal) {
                fail(nullptr);
                return;
            }
        } else {
            quoteVal = builder->CreateGlobalStringPtr("\"");
        }

        // sun_read_csv is declared with three pointer parameters.
        if (filenameVal->getType() != i8PtrTy || sepVal->getType() != i8PtrTy || quoteVal->getType() != i8PtrTy) {
            fail("readcsv() arguments must be strings.");
            return;
        }

        llvm::Function* readCsvFn = module->getFunction("sun_read_csv");
        lastValue = builder->CreateCall(readCsvFn, {filenameVal, sepVal, quoteVal}, "csv_data");
        lastClassName = "CSVArray";
        return;
    }

    // ---- len -------------------------------------------------------------
    if (node.callee == "len") {
        if (node.args.size() != 1) {
            fail("len() takes exactly 1 argument.");
            return;
        }

        node.args[0]->accept(*this);
        llvm::Value* arrVal = lastValue;
        if (!arrVal) {
            fail(nullptr);
            return;
        }
        if (!arrVal->getType()->isPointerTy()) {
            fail("len() expects an array argument.");
            return;
        }

        // Cast to i8* if needed (it should be i8* already as void*).
        if (arrVal->getType() != i8PtrTy) {
            arrVal = builder->CreateBitCast(arrVal, i8PtrTy);
        }

        llvm::Function* lenFn = module->getFunction("sun_array_length");
        lastValue = builder->CreateCall(lenFn, {arrVal}, "len");
        lastClassName = "";  // returns int
        return;
    }

    // ---- print -----------------------------------------------------------
    if (node.callee == "print") {
        if (node.args.size() != 1) {
            fail("print() takes exactly 1 argument.");
            return;
        }

        node.args[0]->accept(*this);
        llvm::Value* argVal = lastValue;
        const std::string argType = lastClassName;
        if (!argVal) {
            fail(nullptr);
            return;
        }

        if (argType == "String") {
            llvm::Function* printStr = module->getFunction("print_string");
            builder->CreateCall(printStr, {argVal});
        } else if (argVal->getType()->isIntegerTy(32)) {
            llvm::Function* printInt = module->getFunction("print_int");
            builder->CreateCall(printInt, {argVal});
        } else {
            // Neither runtime print function accepts this value.
            fail("print() only supports integers and strings.");
            return;
        }

        lastValue = nullptr;  // print returns void
        lastClassName = "";
        return;
    }

    // ---- ordinary function call -------------------------------------------
    llvm::Function* calleeF = module->getFunction(node.callee);
    if (!calleeF) {
        std::cerr << "Unknown function referenced: " << node.callee << std::endl;
        fail(nullptr);
        return;
    }

    if (calleeF->arg_size() != node.args.size()) {
        fail("Incorrect # arguments passed.");
        return;
    }

    llvm::FunctionType* calleeType = calleeF->getFunctionType();
    std::vector<llvm::Value*> argsV;
    argsV.reserve(node.args.size());
    for (unsigned i = 0, e = node.args.size(); i != e; ++i) {
        node.args[i]->accept(*this);
        if (!lastValue) {
            fail(nullptr);
            return;
        }
        // A call whose argument types differ from the signature is invalid IR.
        if (lastValue->getType() != calleeType->getParamType(i)) {
            std::cerr << "Argument type mismatch in call to: " << node.callee << std::endl;
            fail(nullptr);
            return;
        }
        argsV.push_back(lastValue);
    }

    // A void result must not be given a name and yields no value.
    const bool returnsVoid = calleeF->getReturnType()->isVoidTy();
    llvm::CallInst* call = builder->CreateCall(calleeF, argsV, returnsVoid ? "" : "calltmp");
    lastValue = returnsVoid ? nullptr : call;

    // The class of the returned value is not tracked; treat it as untagged.
    lastClassName = "";
}
|
|
|
|
// Generates code for a binary expression.
//
// Both operands are always evaluated, left first. If either failed the result
// is nullptr.
//
//   "+" with at least one operand tagged "String": string concatenation via
//       str_concat; a 32-bit integer operand is converted with int_to_str.
//   "+", "-", "*", "/": integer arithmetic (signed division).
//   "<", ">", "==": signed integer comparison, widened to i32 so that the
//       result is 1 for true and 0 for false.
//
// Unsupported operand types or operators print an error and yield nullptr.
void CodeGen::visit(BinaryExpr& node) {
    node.left->accept(*this);
    llvm::Value* L = lastValue;
    const std::string typeL = lastClassName;

    node.right->accept(*this);
    llvm::Value* R = lastValue;
    const std::string typeR = lastClassName;

    if (!L || !R) {
        lastValue = nullptr;
        return;
    }

    // String concatenation
    if (node.op == "+") {
        const bool isStrL = (typeL == "String");
        const bool isStrR = (typeR == "String");

        if (isStrL || isStrR) {
            // int_to_str takes an i32; any other non-string operand (array,
            // object, ...) cannot be converted.
            const bool convertibleL = isStrL || L->getType()->isIntegerTy(32);
            const bool convertibleR = isStrR || R->getType()->isIntegerTy(32);
            if (!convertibleL || !convertibleR) {
                std::cerr << "Only strings and integers can be concatenated." << std::endl;
                lastValue = nullptr;
                lastClassName = "";
                return;
            }

            llvm::Function* intToStr = module->getFunction("int_to_str");
            llvm::Value* strL = isStrL ? L : builder->CreateCall(intToStr, {L}, "l_str");
            llvm::Value* strR = isStrR ? R : builder->CreateCall(intToStr, {R}, "r_str");

            llvm::Function* strConcat = module->getFunction("str_concat");
            lastValue = builder->CreateCall(strConcat, {strL, strR}, "concat");
            lastClassName = "String";
            return;
        }
    }

    // Pointer arithmetic and struct comparison are not supported.
    if (!L->getType()->isIntegerTy() || !R->getType()->isIntegerTy()) {
        std::cerr << "Binary operations only supported for integers." << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    // Widens an i1 comparison result to i32. Zero extension is required:
    // sign-extending a true i1 would produce -1 instead of 1.
    const auto boolToInt = [this](llvm::Value* flag) {
        return builder->CreateIntCast(flag, llvm::Type::getInt32Ty(*context), false, "booltmp");
    };

    if (node.op == "+") {
        lastValue = builder->CreateAdd(L, R, "addtmp");
    } else if (node.op == "-") {
        lastValue = builder->CreateSub(L, R, "subtmp");
    } else if (node.op == "*") {
        lastValue = builder->CreateMul(L, R, "multmp");
    } else if (node.op == "/") {
        lastValue = builder->CreateSDiv(L, R, "divtmp");
    } else if (node.op == "<") {
        lastValue = boolToInt(builder->CreateICmpSLT(L, R, "cmptmp"));
    } else if (node.op == ">") {
        lastValue = boolToInt(builder->CreateICmpSGT(L, R, "cmptmp"));
    } else if (node.op == "==") {
        lastValue = boolToInt(builder->CreateICmpEQ(L, R, "cmptmp"));
    } else {
        std::cerr << "Unknown operator: " << node.op << std::endl;
        lastValue = nullptr;
    }

    // The result of an integer operation is never a class instance.
    lastClassName = "";
}
|
|
|
|
// Generates a return from the function currently being emitted.
//
// With a value: the expression is evaluated and returned; if evaluation failed
// nothing is emitted. Without a value: `ret void` is emitted only when the
// enclosing function really returns void. Otherwise the zero value of its
// return type is returned, because `ret void` in a non-void function is
// invalid IR.
void CodeGen::visit(ReturnStmt& node) {
    if (node.value) {
        node.value->accept(*this);
        if (lastValue) {
            builder->CreateRet(lastValue);
        }
        return;
    }

    llvm::Type* returnType = builder->GetInsertBlock()->getParent()->getReturnType();
    if (returnType->isVoidTy()) {
        builder->CreateRetVoid();
    } else {
        builder->CreateRet(llvm::Constant::getNullValue(returnType));
    }
}
|
|
|
|
// Declares a local variable in the function currently being emitted.
//
// The variable's stack slot takes the type of its initial value: the
// initializer's type when one is given, otherwise i32 with the value 0. The
// slot is registered in `namedValues`, and `varTypes` records the
// initializer's class tag (or no tag when it has none).
//
// If the initializer fails to generate, nothing is declared.
void CodeGen::visit(VarDeclStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::Value* initVal = nullptr;
    std::string initClassName;

    if (node.initializer) {
        node.initializer->accept(*this);
        initVal = lastValue;
        initClassName = lastClassName;
        if (!initVal) {
            // No value to size the slot from.
            return;
        }
    } else {
        initVal = llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true));
    }

    llvm::AllocaInst* alloca = createEntryBlockAlloca(theFunction, node.name, initVal->getType());
    builder->CreateStore(initVal, alloca);
    namedValues[node.name] = alloca;

    if (!initClassName.empty()) {
        varTypes[node.name] = initClassName;
    } else {
        // A previous variable of the same name may have left a tag behind
        // (e.g. "String"); it must not apply to this declaration.
        varTypes.erase(node.name);
    }
}
|
|
|
|
// Generates the statements of a block in order.
//
// Generation stops as soon as the current basic block already ends in a
// terminator (for example after a `return`): instructions appended after a
// terminator would make the block invalid, and the remaining statements can
// never execute.
void CodeGen::visit(BlockStmt& node) {
    for (const auto& stmt : node.statements) {
        llvm::BasicBlock* current = builder->GetInsertBlock();
        if (current && current->getTerminator()) {
            break;
        }
        stmt->accept(*this);
    }
}
|
|
|
|
// Generates an if / else statement.
//
// The condition value is compared against i32 0; non-zero selects the "then"
// block. Both arms fall through to a common "ifcont" block unless they already
// end in a terminator. When there is no else branch the "else" block is still
// created and simply jumps to "ifcont". Afterwards the builder is positioned
// in "ifcont".
void CodeGen::visit(IfStmt& node) {
    node.condition->accept(*this);
    if (!lastValue) return;

    llvm::Value* const zero = llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true));
    llvm::Value* const isTrue = builder->CreateICmpNE(lastValue, zero, "ifcond");

    llvm::Function* const parent = builder->GetInsertBlock()->getParent();

    // Only "then" is attached right away; the others are appended when reached
    // so that the blocks appear in source order.
    llvm::BasicBlock* const thenBlock = llvm::BasicBlock::Create(*context, "then", parent);
    llvm::BasicBlock* const elseBlock = llvm::BasicBlock::Create(*context, "else");
    llvm::BasicBlock* const joinBlock = llvm::BasicBlock::Create(*context, "ifcont");

    builder->CreateCondBr(isTrue, thenBlock, elseBlock);

    // Jumps to the join block unless the current block is already terminated
    // (e.g. by a return).
    const auto fallThroughToJoin = [this, joinBlock]() {
        if (builder->GetInsertBlock()->getTerminator() == nullptr) {
            builder->CreateBr(joinBlock);
        }
    };

    // then arm
    builder->SetInsertPoint(thenBlock);
    node.thenBranch->accept(*this);
    fallThroughToJoin();

    // else arm (possibly empty)
    parent->insert(parent->end(), elseBlock);
    builder->SetInsertPoint(elseBlock);
    if (node.elseBranch) {
        node.elseBranch->accept(*this);
    }
    fallThroughToJoin();

    // continuation
    parent->insert(parent->end(), joinBlock);
    builder->SetInsertPoint(joinBlock);
}
|
|
|
|
// Generates a while loop.
//
// Layout: the current block jumps to "loopcond", which evaluates the condition
// (non-zero means continue) and branches to "loopbody" or "loopafter". The
// body jumps back to "loopcond". `break` inside the body targets "loopafter"
// through `breakStack`. Afterwards the builder is positioned in "loopafter".
//
// If the condition fails to generate, the function returns without emitting
// the rest of the loop.
void CodeGen::visit(WhileStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();

    // Jump to the condition block.
    llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction);
    builder->CreateBr(condBB);

    // Condition block
    builder->SetInsertPoint(condBB);
    node.condition->accept(*this);
    llvm::Value* condV = lastValue;
    if (!condV) return;

    condV = builder->CreateICmpNE(condV, llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)), "loopcond");

    // Created only after the condition succeeded, so the early return above
    // cannot leave detached blocks behind.
    llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody", theFunction);
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter");
    builder->CreateCondBr(condV, bodyBB, afterBB);

    // Body block
    builder->SetInsertPoint(bodyBB);
    breakStack.push_back(afterBB);
    node.body->accept(*this);
    breakStack.pop_back();

    // Loop back to the condition, unless the body already ended the block
    // (e.g. with a return); a second terminator would be invalid IR.
    if (!builder->GetInsertBlock()->getTerminator()) {
        builder->CreateBr(condBB);
    }

    // After block
    theFunction->insert(theFunction->end(), afterBB);
    builder->SetInsertPoint(afterBB);
}
|
|
|
|
// Generates a C-style for loop: init; condition; increment.
//
// The init statement (if any) runs once in the current block. "loopcond"
// evaluates the condition (non-zero means continue; a missing condition is
// always true) and branches to "loopbody" or "loopafter". After the body the
// increment (if any) runs and control returns to "loopcond". `break` targets
// "loopafter" through `breakStack`. Afterwards the builder is positioned in
// "loopafter".
//
// If the condition fails to generate, the function returns without emitting
// the rest of the loop.
void CodeGen::visit(ForStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();

    // Emit init
    if (node.init) {
        node.init->accept(*this);
    }

    // Jump to the condition block.
    llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction);
    builder->CreateBr(condBB);

    // Condition block
    builder->SetInsertPoint(condBB);
    llvm::Value* condV = nullptr;
    if (node.condition) {
        node.condition->accept(*this);
        condV = lastValue;
        if (!condV) return;
        condV = builder->CreateICmpNE(condV, llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)), "loopcond");
    } else {
        // Infinite loop if no condition. getTrue() is used rather than building
        // the i1 constant from a signed APInt, where the value 1 does not fit
        // into a single signed bit.
        condV = llvm::ConstantInt::getTrue(*context);
    }

    // Created only after the condition succeeded, so the early return above
    // cannot leave detached blocks behind.
    llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody", theFunction);
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter");
    builder->CreateCondBr(condV, bodyBB, afterBB);

    // Body block
    builder->SetInsertPoint(bodyBB);
    breakStack.push_back(afterBB);
    node.body->accept(*this);
    breakStack.pop_back();

    // Increment and loop back, unless the body already ended the block (e.g.
    // with a return); nothing may follow a terminator.
    if (!builder->GetInsertBlock()->getTerminator()) {
        if (node.increment) {
            node.increment->accept(*this);
        }
        builder->CreateBr(condBB);
    }

    // After block
    theFunction->insert(theFunction->end(), afterBB);
    builder->SetInsertPoint(afterBB);
}
|
|
|
|
// Generates `for (variable in collection)` over a runtime array.
//
// The collection is evaluated once and its length is read once with
// sun_array_length. A hidden i32 index variable counts from 0 while it is less
// than that length. In every iteration the element is fetched with
// sun_array_get and stored into the loop variable:
//   collection tagged "StringArray" -> element is a pointer tagged "String"
//   collection tagged "CSVArray"    -> element is a pointer tagged "StringArray"
//                                      (same mapping as index expressions)
//   anything else                   -> element pointer converted to an i32
// `break` targets "loopafter" through `breakStack`. Afterwards the builder is
// positioned in "loopafter".
void CodeGen::visit(ForInStmt& node) {
    // 1. Evaluate collection (array)
    node.collection->accept(*this);
    llvm::Value* arrPtr = lastValue;
    const std::string arrayType = lastClassName;

    if (!arrPtr) return;
    if (!arrPtr->getType()->isPointerTy()) {
        // The runtime array functions take a pointer.
        std::cerr << "for-in loop requires an array." << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    // 2. Get array length
    llvm::Function* lenFn = module->getFunction("sun_array_length");
    llvm::Value* lenVal = builder->CreateCall(lenFn, {arrPtr}, "len");

    // 3. Create loop variable (index)
    llvm::Type* const i32Ty = llvm::Type::getInt32Ty(*context);
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::AllocaInst* indexAlloca = createEntryBlockAlloca(theFunction, "index_" + node.variableName, i32Ty);
    builder->CreateStore(llvm::ConstantInt::get(*context, llvm::APInt(32, 0)), indexAlloca);

    // 4. Create loop blocks
    llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction);
    llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody");
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter");

    builder->CreateBr(condBB);

    // 5. Condition: index < length
    builder->SetInsertPoint(condBB);
    llvm::Value* indexVal = builder->CreateLoad(i32Ty, indexAlloca, "index");
    llvm::Value* condV = builder->CreateICmpSLT(indexVal, lenVal, "loopcond");
    builder->CreateCondBr(condV, bodyBB, afterBB);

    // 6. Body
    theFunction->insert(theFunction->end(), bodyBB);
    builder->SetInsertPoint(bodyBB);

    // Fetch element at index
    llvm::Function* getFn = module->getFunction("sun_array_get");
    llvm::Value* elemPtr = builder->CreateCall(getFn, {arrPtr, indexVal}, "elem");

    llvm::Value* elemVal = nullptr;
    std::string elemClassName;
    if (arrayType == "StringArray") {
        elemVal = elemPtr;  // already i8*
        elemClassName = "String";
    } else if (arrayType == "CSVArray") {
        // Each row is itself an array of strings; keep the pointer intact
        // instead of truncating it to an integer.
        elemVal = elemPtr;
        elemClassName = "StringArray";
    } else {
        // Assume IntArray
        elemVal = builder->CreatePtrToInt(elemPtr, i32Ty, "elemInt");
    }

    // Bind the loop variable. An existing variable of the same name is reused
    // only if its slot has the element's type; otherwise a fresh slot is
    // created so the store below matches the slot type.
    llvm::AllocaInst* varAlloca = nullptr;
    const auto existing = namedValues.find(node.variableName);
    if (existing != namedValues.end() && existing->second->getAllocatedType() == elemVal->getType()) {
        varAlloca = existing->second;
    } else {
        varAlloca = createEntryBlockAlloca(theFunction, node.variableName, elemVal->getType());
        namedValues[node.variableName] = varAlloca;
    }

    builder->CreateStore(elemVal, varAlloca);

    // Update type tracking; drop a stale tag for integer elements.
    if (!elemClassName.empty()) {
        varTypes[node.variableName] = elemClassName;
    } else {
        varTypes.erase(node.variableName);
    }

    // Execute body
    breakStack.push_back(afterBB);
    node.body->accept(*this);
    breakStack.pop_back();

    // Increment the index and loop back, unless the body already ended the
    // block (e.g. with a return); nothing may follow a terminator.
    if (!builder->GetInsertBlock()->getTerminator()) {
        llvm::Value* nextIndex = builder->CreateAdd(indexVal, llvm::ConstantInt::get(*context, llvm::APInt(32, 1)), "nextindex");
        builder->CreateStore(nextIndex, indexAlloca);
        builder->CreateBr(condBB);
    }

    // After block
    theFunction->insert(theFunction->end(), afterBB);
    builder->SetInsertPoint(afterBB);
}
|
|
|
|
// Generates an expression used as a statement by visiting the wrapped
// expression. Whatever the expression leaves in `lastValue` is not used here.
void CodeGen::visit(ExpressionStmt& node) {
    auto& wrapped = *node.expression;
    wrapped.accept(*this);
}
|
|
|
|
// Generates a user-defined function.
//
// Every parameter and the return value are i32. Parameters are copied into
// stack slots so the body can read and assign them like ordinary variables.
// The body sees only its own variables: `namedValues` and `varTypes` are
// emptied for the duration of the function and restored afterwards.
//
// If control can reach the end of the body without a return, `ret i32 0` is
// appended so the function is well formed. The finished function is verified
// and a message is printed when verification fails.
//
// Note: the builder is left positioned inside the new function.
void CodeGen::visit(FunctionDef& node) {
    // Save the enclosing scope (which might be main's locals). The type tags
    // are saved too: a tag recorded inside this function must not be applied
    // to an unrelated variable of the same name elsewhere.
    auto oldNamedValues = namedValues;
    auto oldVarTypes = varTypes;

    // 1. Define function type (i32, i32, ... -> i32)
    llvm::Type* const i32Ty = llvm::Type::getInt32Ty(*context);
    std::vector<llvm::Type*> ints(node.args.size(), i32Ty);
    llvm::FunctionType* ft = llvm::FunctionType::get(i32Ty, ints, false);

    // 2. Create function
    llvm::Function* f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, node.name, module.get());

    // 3. Set argument names
    unsigned idx = 0;
    for (auto& arg : f->args()) {
        arg.setName(node.args[idx++]);
    }

    // 4. Create entry block
    llvm::BasicBlock* bb = llvm::BasicBlock::Create(*context, "entry", f);
    builder->SetInsertPoint(bb);

    // 5. Record arguments in the symbol table (as allocas)
    namedValues.clear();
    varTypes.clear();
    for (auto& arg : f->args()) {
        const std::string argName(arg.getName());
        llvm::AllocaInst* alloca = createEntryBlockAlloca(f, argName, arg.getType());
        builder->CreateStore(&arg, alloca);
        namedValues[argName] = alloca;
    }

    // 6. Generate body
    node.body->accept(*this);

    // A body that falls off its end leaves the last block without a
    // terminator; give it a default return value.
    if (!builder->GetInsertBlock()->getTerminator()) {
        builder->CreateRet(llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)));
    }

    // 7. Verify function (verifyFunction returns true when the IR is broken)
    if (llvm::verifyFunction(*f, &llvm::errs())) {
        std::cerr << "Invalid IR generated for function: " << node.name << std::endl;
    }

    // Restore the enclosing scope
    namedValues = oldNamedValues;
    varTypes = oldVarTypes;
}
|
|
|
|
// Registers a class: creates a named LLVM struct type with one i32 member per
// declared field, and records the struct type and the field name list under
// the class name.
void CodeGen::visit(ClassDef& node) {
    // All fields are i32 for now.
    llvm::Type* const fieldTy = llvm::Type::getInt32Ty(*context);
    const std::vector<llvm::Type*> layout(node.fields.size(), fieldTy);

    llvm::StructType* const classTy = llvm::StructType::create(*context, layout, node.name);

    classStructs[node.name] = classTy;
    classFields[node.name] = node.fields;
}
|
|
|
|
// Creates an instance of a registered class.
//
// The object is an alloca of the class's struct type placed in the entry block
// of the current function (stack allocation, not heap). The result is the
// pointer to that storage, tagged with the class name. An unknown class name
// prints an error and yields nullptr.
void CodeGen::visit(NewExpr& node) {
    const auto classIt = classStructs.find(node.className);
    if (classIt == classStructs.end()) {
        std::cerr << "Unknown class: " << node.className << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    llvm::StructType* const layout = classIt->second;
    llvm::Function* const owner = builder->GetInsertBlock()->getParent();

    // The "object" is a pointer to the struct.
    lastValue = createEntryBlockAlloca(owner, "new_" + node.className, layout);
    lastClassName = node.className;
}
|
|
|
|
// Reads a field of an object: `object.name`.
//
// The object expression must produce a pointer tagged with a registered class
// name. The field's position in the class's field list selects the struct
// member, which is loaded as an i32. Every failure (untagged object, unknown
// class, unknown field) prints an error and yields nullptr; a failed object
// expression returns silently.
void CodeGen::visit(GetPropExpr& node) {
    // 1. Evaluate the object expression.
    node.object->accept(*this);
    llvm::Value* const target = lastValue;
    const std::string ownerClass = lastClassName;

    if (!target) return;

    if (ownerClass.empty()) {
        std::cerr << "Cannot determine class type for property access." << std::endl;
        lastValue = nullptr;
        return;
    }

    const auto fieldsIt = classFields.find(ownerClass);
    if (fieldsIt == classFields.end()) {
        std::cerr << "Unknown class type: " << ownerClass << std::endl;
        lastValue = nullptr;
        return;
    }

    // 2. Locate the field; the first match wins.
    const auto& fieldNames = fieldsIt->second;
    int position = -1;
    for (size_t i = 0; i < fieldNames.size() && position == -1; ++i) {
        if (fieldNames[i] == node.name) {
            position = i;
        }
    }

    if (position == -1) {
        std::cerr << "Unknown field: " << node.name << " in class " << ownerClass << std::endl;
        lastValue = nullptr;
        return;
    }

    // 3. Address the member (first index steps through the pointer, second
    //    selects the field) and load it.
    llvm::Value* const throughPointer = llvm::ConstantInt::get(*context, llvm::APInt(32, 0));
    llvm::Value* const memberIndex = llvm::ConstantInt::get(*context, llvm::APInt(32, position));

    llvm::Type* const layout = classStructs[ownerClass];
    llvm::Value* const memberPtr = builder->CreateGEP(layout, target, {throughPointer, memberIndex}, "fieldptr");

    lastValue = builder->CreateLoad(llvm::Type::getInt32Ty(*context), memberPtr, "fieldval");
    lastClassName = "";  // fields are integers
}
|
|
|
|
// Writes a field of an object: `object.name = value`.
//
// The object expression is evaluated first, then the value. The object must be
// tagged with a registered class name and the value must have the type of the
// addressed struct member. On success the value is stored and becomes the
// result of the expression.
//
// Every failure leaves `lastValue` as nullptr; untagged objects, unknown
// classes, unknown fields and type mismatches also print an error.
void CodeGen::visit(SetPropExpr& node) {
    // 1. Evaluate object
    node.object->accept(*this);
    llvm::Value* objectPtr = lastValue;
    const std::string className = lastClassName;

    if (!objectPtr) return;

    // 2. Evaluate value
    node.value->accept(*this);
    llvm::Value* val = lastValue;
    if (!val) return;  // nothing to store

    if (className.empty()) {
        std::cerr << "Cannot determine class type for property set." << std::endl;
        lastValue = nullptr;
        return;
    }

    // Look the class up without operator[], which would silently register an
    // empty entry for names that are not classes (e.g. "String").
    const auto fieldsIt = classFields.find(className);
    const auto structIt = classStructs.find(className);
    if (fieldsIt == classFields.end() || structIt == classStructs.end()) {
        std::cerr << "Unknown class type: " << className << std::endl;
        lastValue = nullptr;
        return;
    }

    const auto& fields = fieldsIt->second;
    int fieldIndex = -1;
    for (size_t i = 0; i < fields.size(); ++i) {
        if (fields[i] == node.name) {
            fieldIndex = static_cast<int>(i);
            break;
        }
    }

    if (fieldIndex == -1) {
        std::cerr << "Unknown field: " << node.name << " in class " << className << std::endl;
        lastValue = nullptr;
        return;
    }

    // Storing a value of another type than the member would overwrite
    // neighbouring fields or leave part of the member unwritten.
    llvm::StructType* structType = structIt->second;
    if (val->getType() != structType->getElementType(static_cast<unsigned>(fieldIndex))) {
        std::cerr << "Type mismatch for field: " << node.name << " in class " << className << std::endl;
        lastValue = nullptr;
        return;
    }

    // 3. Address the member and store.
    std::vector<llvm::Value*> indices;
    indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, 0)));           // step through the pointer
    indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, fieldIndex)));  // field index

    llvm::Value* fieldPtr = builder->CreateGEP(structType, objectPtr, indices, "fieldptr");
    builder->CreateStore(val, fieldPtr);
    lastValue = val;
}
|
|
|
|
// Generates an array literal.
//
// A runtime array of the literal's size is created with sun_array_create and
// each element is stored with sun_array_set. Elements are passed as pointers:
// pointer values are cast, integer values are converted with inttoptr. The
// result is tagged "StringArray" if any element is tagged "String", otherwise
// "IntArray".
//
// If an element fails to generate, the result is nullptr.
void CodeGen::visit(ArrayExpr& node) {
    const int size = static_cast<int>(node.elements.size());
    llvm::Function* createFn = module->getFunction("sun_array_create");
    llvm::Function* setFn = module->getFunction("sun_array_set");
    llvm::Type* const i8PtrTy = llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0);

    llvm::Value* sizeVal = llvm::ConstantInt::get(*context, llvm::APInt(32, size));
    llvm::Value* arrPtr = builder->CreateCall(createFn, {sizeVal}, "array");

    std::string elemType = "IntArray";  // default to IntArray

    for (int i = 0; i < size; ++i) {
        node.elements[i]->accept(*this);
        llvm::Value* val = lastValue;
        if (!val) {
            // The element could not be generated; there is no value whose type
            // could be inspected or stored.
            lastValue = nullptr;
            lastClassName = "";
            return;
        }

        if (lastClassName == "String") {
            elemType = "StringArray";
        }

        // Cast value to i8* (void*)
        llvm::Value* voidVal = nullptr;
        if (val->getType()->isPointerTy()) {
            voidVal = builder->CreateBitCast(val, i8PtrTy);
        } else {
            // Assume int (i32): convert directly with inttoptr.
            voidVal = builder->CreateIntToPtr(val, i8PtrTy);
        }

        llvm::Value* indexVal = llvm::ConstantInt::get(*context, llvm::APInt(32, i));
        builder->CreateCall(setFn, {arrPtr, indexVal, voidVal});
    }

    lastValue = arrPtr;
    lastClassName = elemType;
}
|
|
|
|
// Reads an array element: `array[index]`.
//
// The element is fetched with sun_array_get and interpreted according to the
// array's tag:
//   "StringArray" -> pointer tagged "String"
//   "CSVArray"    -> pointer tagged "StringArray" (a row)
//   anything else -> pointer converted back to an i32
//
// The array must be a pointer and the index a 32-bit integer, matching the
// declared signature of sun_array_get; otherwise an error is printed. Any
// failure yields nullptr.
void CodeGen::visit(IndexExpr& node) {
    node.array->accept(*this);
    llvm::Value* arrPtr = lastValue;
    const std::string arrayType = lastClassName;
    if (!arrPtr) {
        lastClassName = "";
        return;
    }

    node.index->accept(*this);
    llvm::Value* indexVal = lastValue;
    if (!indexVal) {
        lastClassName = "";
        return;
    }

    if (!arrPtr->getType()->isPointerTy() || !indexVal->getType()->isIntegerTy(32)) {
        std::cerr << "Indexing requires an array and an integer index." << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    llvm::Function* getFn = module->getFunction("sun_array_get");
    llvm::Value* valPtr = builder->CreateCall(getFn, {arrPtr, indexVal}, "elem");

    if (arrayType == "StringArray") {
        lastValue = valPtr;  // already i8*
        lastClassName = "String";
    } else if (arrayType == "CSVArray") {
        lastValue = valPtr;  // a void* (pointer to a row array)
        lastClassName = "StringArray";
    } else {
        // Assume IntArray, cast back to i32
        lastValue = builder->CreatePtrToInt(valPtr, llvm::Type::getInt32Ty(*context), "elemInt");
        lastClassName = "";
    }
}
|
|
|
|
// Writes an array element: `array[index] = value`.
//
// Array, index and value are evaluated in that order. The value is passed to
// sun_array_set as a pointer: pointer values are cast, integer values are
// converted with inttoptr. The stored value is the result of the expression.
//
// The array must be a pointer and the index a 32-bit integer, matching the
// declared signature of sun_array_set; otherwise an error is printed. Any
// failure yields nullptr.
void CodeGen::visit(ArrayAssignExpr& node) {
    node.array->accept(*this);
    llvm::Value* arrPtr = lastValue;
    if (!arrPtr) return;

    node.index->accept(*this);
    llvm::Value* indexVal = lastValue;
    if (!indexVal) return;

    node.value->accept(*this);
    llvm::Value* val = lastValue;
    if (!val) return;

    if (!arrPtr->getType()->isPointerTy() || !indexVal->getType()->isIntegerTy(32)) {
        std::cerr << "Indexing requires an array and an integer index." << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    llvm::Function* setFn = module->getFunction("sun_array_set");
    llvm::Type* const i8PtrTy = llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0);

    // Cast value to i8*
    llvm::Value* voidVal = nullptr;
    if (val->getType()->isPointerTy()) {
        voidVal = builder->CreateBitCast(val, i8PtrTy);
    } else {
        voidVal = builder->CreateIntToPtr(val, i8PtrTy);
    }

    builder->CreateCall(setFn, {arrPtr, indexVal, voidVal});
    lastValue = val;
}
|
|
|
|
// Generates a switch statement as a chain of equality tests.
//
// The condition is evaluated once. Each case value is compared against it in
// order; the first match jumps to that case's body, and if none matches
// control goes to the default block. A case body that does not end in a
// terminator falls through to the next case body (the last one falls through
// to the default block). `break` targets "switchmerge" through `breakStack`.
// Afterwards the builder is positioned in "switchmerge".
//
// A case value that fails to generate, or whose type differs from the
// condition's, is reported and never matches; its body stays reachable by
// fall-through only.
void CodeGen::visit(SwitchStmt& node) {
    // 1. Evaluate condition
    node.condition->accept(*this);
    llvm::Value* condVal = lastValue;
    if (!condVal) return;

    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    const size_t caseCount = node.cases.size();

    // 2. Create merge block (after switch) and make it the target of 'break'
    llvm::BasicBlock* mergeBB = llvm::BasicBlock::Create(*context, "switchmerge");
    breakStack.push_back(mergeBB);

    // 3. Create default block
    llvm::BasicBlock* defaultBB = llvm::BasicBlock::Create(*context, "switchdefault");

    // 4. Create blocks for case bodies
    std::vector<llvm::BasicBlock*> caseBBs;
    caseBBs.reserve(caseCount);
    for (size_t i = 0; i < caseCount; ++i) {
        caseBBs.push_back(llvm::BasicBlock::Create(*context, "case" + std::to_string(i)));
    }

    // 5. Generate comparisons (if-else chain)
    for (size_t i = 0; i < caseCount; ++i) {
        node.cases[i].value->accept(*this);
        llvm::Value* caseVal = lastValue;

        // A failed test continues with the next test; the last one continues
        // with the default block.
        const bool isLast = (i + 1 == caseCount);
        llvm::BasicBlock* nextTestBB =
            isLast ? defaultBB : llvm::BasicBlock::Create(*context, "nexttest", theFunction);

        if (caseVal && caseVal->getType() == condVal->getType()) {
            llvm::Value* cmp = builder->CreateICmpEQ(condVal, caseVal, "casecmp");
            builder->CreateCondBr(cmp, caseBBs[i], nextTestBB);
        } else {
            // No comparison can be built from a missing or differently typed
            // value; skip this test but keep the chain well formed.
            std::cerr << "Invalid case value in switch statement." << std::endl;
            builder->CreateBr(nextTestBB);
        }

        if (!isLast) {
            builder->SetInsertPoint(nextTestBB);
        }
    }

    if (caseCount == 0) {
        builder->CreateBr(defaultBB);
    }

    // 6. Generate case bodies
    for (size_t i = 0; i < caseCount; ++i) {
        theFunction->insert(theFunction->end(), caseBBs[i]);
        builder->SetInsertPoint(caseBBs[i]);

        node.cases[i].body->accept(*this);

        // Fallthrough
        if (!builder->GetInsertBlock()->getTerminator()) {
            builder->CreateBr(i + 1 < caseCount ? caseBBs[i + 1] : defaultBB);
        }
    }

    // 7. Generate default body
    theFunction->insert(theFunction->end(), defaultBB);
    builder->SetInsertPoint(defaultBB);
    if (node.defaultCase) {
        node.defaultCase->accept(*this);
    }

    if (!builder->GetInsertBlock()->getTerminator()) {
        builder->CreateBr(mergeBB);
    }

    // 8. Finish
    breakStack.pop_back();
    theFunction->insert(theFunction->end(), mergeBB);
    builder->SetInsertPoint(mergeBB);
}
|
|
|
|
// Generates a `break`: jumps to the innermost target on `breakStack`.
//
// Since the branch terminates the current block, a fresh "dead" block is
// appended to the function and made the insertion point, so that code
// generated after the break has somewhere to go. With an empty `breakStack`
// an error is printed and nothing is emitted.
void CodeGen::visit(BreakStmt& node) {
    if (!breakStack.empty()) {
        builder->CreateBr(breakStack.back());

        llvm::Function* const owner = builder->GetInsertBlock()->getParent();
        llvm::BasicBlock* const afterBreak = llvm::BasicBlock::Create(*context, "dead");
        owner->insert(owner->end(), afterBreak);
        builder->SetInsertPoint(afterBreak);
        return;
    }

    std::cerr << "Error: break statement outside loop or switch" << std::endl;
}
|