sunlang/src/codegen.cpp

934 lines
33 KiB
C++

#include "codegen.h"
#include <iostream>
/// Creates the LLVM context, module ("sun_module") and IR builder, then adds
/// an external declaration for every runtime function the generated code calls.
CodeGen::CodeGen() {
    context = std::make_unique<llvm::LLVMContext>();
    module = std::make_unique<llvm::Module>("sun_module", *context);
    builder = std::make_unique<llvm::IRBuilder<>>(*context);

    // The runtime signatures below only use three types: void, i32 and i8*
    // (i8* stands in for both char* strings and void* array handles).
    llvm::Type* const voidTy = llvm::Type::getVoidTy(*context);
    llvm::Type* const i32Ty = llvm::Type::getInt32Ty(*context);
    llvm::Type* const bytePtrTy = llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0);

    // Adds one non-variadic, externally linked declaration to the module.
    const auto declareRuntime = [this](const char* name, llvm::Type* result, llvm::ArrayRef<llvm::Type*> params) {
        llvm::FunctionType* signature = llvm::FunctionType::get(result, params, false);
        llvm::Function::Create(signature, llvm::Function::ExternalLinkage, name, module.get());
    };

    // void print_int(int)
    declareRuntime("print_int", voidTy, {i32Ty});
    // void print_string(char*)
    declareRuntime("print_string", voidTy, {bytePtrTy});
    // char* int_to_str(int)
    declareRuntime("int_to_str", bytePtrTy, {i32Ty});
    // char* str_concat(char*, char*)
    declareRuntime("str_concat", bytePtrTy, {bytePtrTy, bytePtrTy});
    // void* sun_array_create(int size)
    declareRuntime("sun_array_create", bytePtrTy, {i32Ty});
    // void sun_array_set(void* arr, int index, void* value)
    declareRuntime("sun_array_set", voidTy, {bytePtrTy, i32Ty, bytePtrTy});
    // void* sun_array_get(void* arr, int index)
    declareRuntime("sun_array_get", bytePtrTy, {bytePtrTy, i32Ty});
    // int sun_array_length(void* arr)
    declareRuntime("sun_array_length", i32Ty, {bytePtrTy});
    // void* sun_read_csv(char* filename, char* separator, char* quote)
    declareRuntime("sun_read_csv", bytePtrTy, {bytePtrTy, bytePtrTy, bytePtrTy});
}
/// Emits an alloca of `type`, named `varName`, at the very beginning of the
/// entry block of `theFunction` and returns the new instruction.
llvm::AllocaInst* CodeGen::createEntryBlockAlloca(llvm::Function* theFunction, const std::string& varName, llvm::Type* type) {
    llvm::BasicBlock& entry = theFunction->getEntryBlock();
    // A throwaway builder is used so the main builder's insertion point is untouched.
    llvm::IRBuilder<> entryBuilder(&entry, entry.begin());
    return entryBuilder.CreateAlloca(type, nullptr, varName);
}
/// Generates code for one top-level AST node.
///
/// Function and class definitions are visited directly. Any other node is
/// treated as a top-level statement and emitted into an implicit `main`
/// function, which is created on first use. `mainInsertBlock` remembers where
/// the next top-level statement must continue.
void CodeGen::generate(ASTNode& node) {
    const bool isDefinition = dynamic_cast<FunctionDef*>(&node) || dynamic_cast<ClassDef*>(&node);

    if (isDefinition) {
        // Only remember the insertion block when it really belongs to main.
        // After a previous function definition the builder still points into
        // that function; recording that block would send later top-level
        // statements into the wrong function.
        llvm::BasicBlock* current = builder->GetInsertBlock();
        if (mainFunction && current && current->getParent() == mainFunction) {
            mainInsertBlock = current;
        }
        node.accept(*this);
        return;
    }

    if (!mainFunction) {
        // First top-level statement: create `int main()` and its entry block.
        llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getInt32Ty(*context), false);
        mainFunction = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "main", module.get());
        llvm::BasicBlock* bb = llvm::BasicBlock::Create(*context, "entry", mainFunction);
        builder->SetInsertPoint(bb);
    } else if (mainInsertBlock) {
        // Resume main where the previous top-level statement left off.
        builder->SetInsertPoint(mainInsertBlock);
    }

    node.accept(*this);
    // The statement may have created new blocks (if/loops); continue after them.
    mainInsertBlock = builder->GetInsertBlock();
}
void CodeGen::finish() {
if (mainFunction) {
if (mainInsertBlock && !mainInsertBlock->getTerminator()) {
builder->SetInsertPoint(mainInsertBlock);
builder->CreateRet(llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)));
}
}
}
void CodeGen::print() {
module->print(llvm::outs(), nullptr);
}
/// Produces a 32-bit signed integer constant for a numeric literal.
void CodeGen::visit(NumberExpr& node) {
    const llvm::APInt literalBits(32, node.value, true);
    lastClassName = "";  // plain integer, not a class instance
    lastValue = llvm::ConstantInt::get(*context, literalBits);
}
/// Emits a global string for a string literal and tags the result with the
/// "String" type marker.
void CodeGen::visit(StringExpr& node) {
    llvm::Value* const literal = builder->CreateGlobalString(node.value);
    lastClassName = "String";  // marker that later visitors test for
    lastValue = literal;
}
/// Loads the current value of a named variable from its stack slot.
/// Reports an error and yields a null value if the name is unknown.
void CodeGen::visit(VariableExpr& node) {
    const auto slot = namedValues.find(node.name);
    if (slot == namedValues.end()) {
        std::cerr << "Unknown variable name: " << node.name << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    // Variables live in allocas; reading one is a load of the allocated type.
    llvm::AllocaInst* const storage = slot->second;
    lastValue = builder->CreateLoad(storage->getAllocatedType(), storage, node.name.c_str());

    // Propagate the tracked type marker, if any was recorded for this name.
    const auto tracked = varTypes.find(node.name);
    if (tracked == varTypes.end()) {
        lastClassName = "";
    } else {
        lastClassName = tracked->second;
    }
}
/// Stores the value of the right-hand side into an existing variable.
/// The assigned value is also the value of the assignment expression.
void CodeGen::visit(AssignExpr& node) {
    node.value->accept(*this);
    llvm::Value* const rhs = lastValue;
    const std::string rhsClass = lastClassName;  // type marker of the value
    if (rhs == nullptr) {
        return;
    }

    const auto slot = namedValues.find(node.name);
    if (slot == namedValues.end()) {
        std::cerr << "Unknown variable name: " << node.name << std::endl;
        lastValue = nullptr;
        return;
    }

    builder->CreateStore(rhs, slot->second);
    lastValue = rhs;

    // Type tracking is only refreshed when the value carries a marker.
    if (rhsClass.empty()) {
        return;
    }
    varTypes[node.name] = rhsClass;
    lastClassName = rhsClass;
}
/// Generates a call expression.
///
/// The built-ins `readcsv`, `len` and `print` are lowered to runtime
/// functions; any other callee must already exist in the module. On every
/// failure `lastValue` is null and `lastClassName` is empty, so callers can
/// detect the error without touching a null value.
void CodeGen::visit(CallExpr& node) {
    // Shared failure exit: leaves no stale value or type marker behind.
    const auto fail = [this]() {
        lastValue = nullptr;
        lastClassName = "";
    };

    // Built-in: readcsv(filename[, separator[, quote]])
    if (node.callee == "readcsv") {
        if (node.args.size() < 1 || node.args.size() > 3) {
            std::cerr << "readcsv() takes 1 to 3 arguments." << std::endl;
            fail();
            return;
        }
        node.args[0]->accept(*this);
        llvm::Value* filenameVal = lastValue;
        if (!filenameVal) {  // argument failed to generate; error already reported
            fail();
            return;
        }
        llvm::Value* sepVal = nullptr;
        if (node.args.size() >= 2) {
            node.args[1]->accept(*this);
            sepVal = lastValue;
            if (!sepVal) {
                fail();
                return;
            }
        } else {
            sepVal = builder->CreateGlobalStringPtr(",");  // default separator
        }
        llvm::Value* quoteVal = nullptr;
        if (node.args.size() >= 3) {
            node.args[2]->accept(*this);
            quoteVal = lastValue;
            if (!quoteVal) {
                fail();
                return;
            }
        } else {
            quoteVal = builder->CreateGlobalStringPtr("\"");  // default quote character
        }
        llvm::Function* readCsvFn = module->getFunction("sun_read_csv");
        lastValue = builder->CreateCall(readCsvFn, {filenameVal, sepVal, quoteVal}, "csv_data");
        lastClassName = "CSVArray";
        return;
    }

    // Built-in: len(array)
    if (node.callee == "len") {
        if (node.args.size() != 1) {
            std::cerr << "len() takes exactly 1 argument." << std::endl;
            fail();
            return;
        }
        node.args[0]->accept(*this);
        llvm::Value* arrVal = lastValue;
        if (!arrVal) {  // would otherwise dereference a null Value below
            fail();
            return;
        }
        // Cast to i8* if needed (it should be i8* already as void*)
        llvm::Type* bytePtrTy = llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0);
        if (arrVal->getType() != bytePtrTy) {
            arrVal = builder->CreateBitCast(arrVal, bytePtrTy);
        }
        llvm::Function* lenFn = module->getFunction("sun_array_length");
        lastValue = builder->CreateCall(lenFn, {arrVal}, "len");
        lastClassName = "";  // Returns int
        return;
    }

    // Built-in: print(value), dispatched on the tracked type of the argument
    if (node.callee == "print") {
        if (node.args.size() != 1) {
            std::cerr << "print() takes exactly 1 argument." << std::endl;
            fail();
            return;
        }
        node.args[0]->accept(*this);
        llvm::Value* argVal = lastValue;
        const std::string argType = lastClassName;
        if (!argVal) {
            fail();
            return;
        }
        if (argType == "String") {
            llvm::Function* printStr = module->getFunction("print_string");
            builder->CreateCall(printStr, {argVal});
        } else if (argVal->getType()->isIntegerTy(32)) {
            llvm::Function* printInt = module->getFunction("print_int");
            builder->CreateCall(printInt, {argVal});
        } else {
            // print_int takes an i32; passing anything else is an ill-typed call.
            std::cerr << "print() only supports integers and strings." << std::endl;
            fail();
            return;
        }
        fail();  // print returns void: no value, no type marker
        return;
    }

    // Ordinary call to a function already present in the module.
    llvm::Function* calleeF = module->getFunction(node.callee);
    if (!calleeF) {
        std::cerr << "Unknown function referenced: " << node.callee << std::endl;
        fail();
        return;
    }
    if (calleeF->arg_size() != node.args.size()) {
        std::cerr << "Incorrect # arguments passed." << std::endl;
        fail();
        return;
    }
    std::vector<llvm::Value*> argsV;
    argsV.reserve(node.args.size());
    for (const auto& arg : node.args) {
        arg->accept(*this);
        if (!lastValue) {
            fail();
            return;
        }
        argsV.push_back(lastValue);
    }
    lastValue = builder->CreateCall(calleeF, argsV, "calltmp");
    // We don't know the return type class name easily without function signatures in symbol table.
    // For now, assume int.
    lastClassName = "";
}
/// Generates a binary expression.
///
/// `+` with at least one "String" operand becomes a runtime string
/// concatenation (an integer operand is converted with `int_to_str`). All
/// other supported operators (+ - * / < > ==) work on integers; comparisons
/// yield an i32 that is 1 for true and 0 for false.
void CodeGen::visit(BinaryExpr& node) {
    node.left->accept(*this);
    llvm::Value* L = lastValue;
    std::string typeL = lastClassName;
    node.right->accept(*this);
    llvm::Value* R = lastValue;
    std::string typeR = lastClassName;
    if (!L || !R) {
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    // String Concatenation
    if (node.op == "+") {
        const bool isStrL = (typeL == "String");
        const bool isStrR = (typeR == "String");
        if (isStrL || isStrR) {
            // int_to_str takes an i32, so a non-string operand must be one.
            const bool convertibleL = isStrL || L->getType()->isIntegerTy(32);
            const bool convertibleR = isStrR || R->getType()->isIntegerTy(32);
            if (!convertibleL || !convertibleR) {
                std::cerr << "Cannot concatenate a string with a non-integer value." << std::endl;
                lastValue = nullptr;
                lastClassName = "";
                return;
            }
            llvm::Function* intToStr = module->getFunction("int_to_str");
            llvm::Value* strL = L;
            llvm::Value* strR = R;
            if (!isStrL) {
                strL = builder->CreateCall(intToStr, {L}, "l_str");
            }
            if (!isStrR) {
                strR = builder->CreateCall(intToStr, {R}, "r_str");
            }
            llvm::Function* strConcat = module->getFunction("str_concat");
            lastValue = builder->CreateCall(strConcat, {strL, strR}, "concat");
            lastClassName = "String";
            return;
        }
    }

    // Pointer arithmetic and struct comparison are not supported.
    if (!L->getType()->isIntegerTy() || !R->getType()->isIntegerTy()) {
        std::cerr << "Binary operations only supported for integers." << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    // Every result below is a plain integer.
    lastClassName = "";

    // Widens an i1 comparison result to i32. Zero-extension is required:
    // sign-extending an i1 "true" would produce -1 instead of 1.
    const auto boolToInt = [this](llvm::Value* flag) {
        return builder->CreateZExt(flag, llvm::Type::getInt32Ty(*context), "booltmp");
    };

    if (node.op == "+") {
        lastValue = builder->CreateAdd(L, R, "addtmp");
    } else if (node.op == "-") {
        lastValue = builder->CreateSub(L, R, "subtmp");
    } else if (node.op == "*") {
        lastValue = builder->CreateMul(L, R, "multmp");
    } else if (node.op == "/") {
        lastValue = builder->CreateSDiv(L, R, "divtmp");
    } else if (node.op == "<") {
        lastValue = boolToInt(builder->CreateICmpSLT(L, R, "cmptmp"));
    } else if (node.op == ">") {
        lastValue = boolToInt(builder->CreateICmpSGT(L, R, "cmptmp"));
    } else if (node.op == "==") {
        lastValue = boolToInt(builder->CreateICmpEQ(L, R, "cmptmp"));
    } else {
        std::cerr << "Unknown operator: " << node.op << std::endl;
        lastValue = nullptr;
    }
}
/// Generates a return statement.
///
/// With a value, the value is returned (nothing is emitted if it failed to
/// generate). Without a value, the return matches the enclosing function's
/// return type: `ret void` for void functions, otherwise the zero value of
/// the return type, so that a bare `return` never produces an ill-typed `ret`.
void CodeGen::visit(ReturnStmt& node) {
    if (node.value) {
        node.value->accept(*this);
        if (lastValue) {
            builder->CreateRet(lastValue);
        }
        return;
    }

    llvm::Type* returnType = builder->GetInsertBlock()->getParent()->getReturnType();
    if (returnType->isVoidTy()) {
        builder->CreateRetVoid();
    } else {
        builder->CreateRet(llvm::Constant::getNullValue(returnType));
    }
}
/// Declares a local variable: allocates a stack slot in the entry block of
/// the current function, stores the initial value and registers the name.
///
/// Without an initializer the variable is an i32 initialised to 0. If the
/// initializer fails to generate, nothing is declared.
void CodeGen::visit(VarDeclStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::Value* initVal = nullptr;
    std::string initClassName;
    if (node.initializer) {
        node.initializer->accept(*this);
        initVal = lastValue;
        initClassName = lastClassName;
        if (!initVal) {
            // The initializer already reported its error; without a value
            // there is no type to allocate.
            return;
        }
    } else {
        initVal = llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true));
    }

    // The slot takes the type of the initial value.
    llvm::AllocaInst* alloca = createEntryBlockAlloca(theFunction, node.name, initVal->getType());
    builder->CreateStore(initVal, alloca);
    namedValues[node.name] = alloca;

    // A declaration defines the variable's type from scratch: record the new
    // marker, or drop a marker left over from an earlier variable of the same name.
    if (!initClassName.empty()) {
        varTypes[node.name] = initClassName;
    } else {
        varTypes.erase(node.name);
    }
}
/// Generates every statement of the block in source order.
void CodeGen::visit(BlockStmt& node) {
    for (auto it = node.statements.begin(), last = node.statements.end(); it != last; ++it) {
        (*it)->accept(*this);
    }
}
/// Generates an if/else statement as three blocks: "then", "else" and the
/// join block "ifcont". The else block is always emitted, even when the
/// statement has no else branch.
void CodeGen::visit(IfStmt& node) {
    node.condition->accept(*this);
    if (!lastValue) {
        return;
    }

    // The condition is an i32; it is true when it differs from zero.
    llvm::Value* const zero = llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true));
    llvm::Value* const isTrue = builder->CreateICmpNE(lastValue, zero, "ifcond");

    llvm::Function* const parent = builder->GetInsertBlock()->getParent();
    llvm::BasicBlock* const thenBlock = llvm::BasicBlock::Create(*context, "then", parent);
    llvm::BasicBlock* const elseBlock = llvm::BasicBlock::Create(*context, "else");
    llvm::BasicBlock* const joinBlock = llvm::BasicBlock::Create(*context, "ifcont");
    builder->CreateCondBr(isTrue, thenBlock, elseBlock);

    // Branches to the join block unless the current block already ends in a
    // terminator (for example a return).
    const auto fallIntoJoin = [this, joinBlock]() {
        if (builder->GetInsertBlock()->getTerminator() == nullptr) {
            builder->CreateBr(joinBlock);
        }
    };

    // Then branch.
    builder->SetInsertPoint(thenBlock);
    node.thenBranch->accept(*this);
    fallIntoJoin();

    // Else branch (possibly empty).
    parent->insert(parent->end(), elseBlock);
    builder->SetInsertPoint(elseBlock);
    if (node.elseBranch) {
        node.elseBranch->accept(*this);
    }
    fallIntoJoin();

    // Code after the if continues in the join block.
    parent->insert(parent->end(), joinBlock);
    builder->SetInsertPoint(joinBlock);
}
/// Generates a while loop as "loopcond" -> "loopbody" -> back to "loopcond",
/// leaving the builder in "loopafter". `break` inside the body jumps to
/// "loopafter" through `breakStack`.
void CodeGen::visit(WhileStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction);

    // Jump to condition
    builder->CreateBr(condBB);

    // Condition Block
    builder->SetInsertPoint(condBB);
    node.condition->accept(*this);
    llvm::Value* condV = lastValue;
    if (!condV) return;  // condition failed to generate; error already reported
    condV = builder->CreateICmpNE(condV, llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)), "loopcond");

    // The remaining blocks are created only once the condition is known to be
    // valid, so the early return above cannot leak detached blocks.
    llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody");
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter");
    builder->CreateCondBr(condV, bodyBB, afterBB);

    // Body Block
    theFunction->insert(theFunction->end(), bodyBB);
    builder->SetInsertPoint(bodyBB);
    breakStack.push_back(afterBB);
    node.body->accept(*this);
    breakStack.pop_back();

    // Loop back to the condition, unless the body already ended its block
    // with a terminator (e.g. a return); a block may have only one terminator.
    if (!builder->GetInsertBlock()->getTerminator()) {
        builder->CreateBr(condBB);
    }

    // After Block
    theFunction->insert(theFunction->end(), afterBB);
    builder->SetInsertPoint(afterBB);
}
/// Generates a C-style for loop: init, then "loopcond" -> "loopbody"
/// (body followed by the increment) -> back to "loopcond", leaving the
/// builder in "loopafter". A missing condition means an infinite loop.
void CodeGen::visit(ForStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();

    // Emit init
    if (node.init) {
        node.init->accept(*this);
    }

    llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction);

    // Jump to condition
    builder->CreateBr(condBB);

    // Condition Block
    builder->SetInsertPoint(condBB);
    llvm::Value* condV = nullptr;
    if (node.condition) {
        node.condition->accept(*this);
        condV = lastValue;
        if (!condV) return;  // condition failed to generate; error already reported
        condV = builder->CreateICmpNE(condV, llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)), "loopcond");
    } else {
        // Infinite loop if no condition. getTrue() yields the i1 constant
        // directly; the value 1 is not representable as a *signed* 1-bit APInt.
        condV = builder->getTrue();
    }

    // Created only after the condition succeeded, so the early return above
    // cannot leak detached blocks.
    llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody");
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter");
    builder->CreateCondBr(condV, bodyBB, afterBB);

    // Body Block
    theFunction->insert(theFunction->end(), bodyBB);
    builder->SetInsertPoint(bodyBB);
    breakStack.push_back(afterBB);
    node.body->accept(*this);
    breakStack.pop_back();

    // Increment and back-edge are skipped when the body already terminated
    // its block (e.g. with a return): nothing may follow a terminator.
    if (!builder->GetInsertBlock()->getTerminator()) {
        if (node.increment) {
            node.increment->accept(*this);
        }
        builder->CreateBr(condBB);  // Loop back to condition
    }

    // After Block
    theFunction->insert(theFunction->end(), afterBB);
    builder->SetInsertPoint(afterBB);
}
/// Generates `for <variable> in <collection>`: iterates a runtime array by
/// index from 0 to its length, storing each element in the loop variable
/// before running the body.
///
/// The element type follows the collection's type marker, using the same
/// mapping as index expressions: "StringArray" yields strings, "CSVArray"
/// yields string arrays, anything else is treated as an array of i32.
void CodeGen::visit(ForInStmt& node) {
    // 1. Evaluate collection (array)
    node.collection->accept(*this);
    llvm::Value* arrPtr = lastValue;
    std::string arrayType = lastClassName;
    if (!arrPtr) return;

    // 2. Get Array Length (evaluated once, before the loop)
    llvm::Function* lenFn = module->getFunction("sun_array_length");
    llvm::Value* lenVal = builder->CreateCall(lenFn, {arrPtr}, "len");

    // 3. Create Loop Variable (index)
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::AllocaInst* indexAlloca = createEntryBlockAlloca(theFunction, "index_" + node.variableName, llvm::Type::getInt32Ty(*context));
    builder->CreateStore(llvm::ConstantInt::get(*context, llvm::APInt(32, 0)), indexAlloca);

    // 4. Create Loop Blocks
    llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction);
    llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody");
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter");
    builder->CreateBr(condBB);

    // 5. Condition: index < length
    builder->SetInsertPoint(condBB);
    llvm::Value* indexVal = builder->CreateLoad(llvm::Type::getInt32Ty(*context), indexAlloca, "index");
    llvm::Value* condV = builder->CreateICmpSLT(indexVal, lenVal, "loopcond");
    builder->CreateCondBr(condV, bodyBB, afterBB);

    // 6. Body
    theFunction->insert(theFunction->end(), bodyBB);
    builder->SetInsertPoint(bodyBB);

    // Fetch element at index
    llvm::Function* getFn = module->getFunction("sun_array_get");
    llvm::Value* elemPtr = builder->CreateCall(getFn, {arrPtr, indexVal}, "elem");
    llvm::Value* elemVal = nullptr;
    std::string elemClassName;
    if (arrayType == "StringArray") {
        elemVal = elemPtr;  // Already i8*
        elemClassName = "String";
    } else if (arrayType == "CSVArray") {
        // Each element is itself an array of strings; it must stay a pointer
        // rather than being truncated to an i32.
        elemVal = elemPtr;
        elemClassName = "StringArray";
    } else {
        // Assume IntArray: the integer was stored in the pointer slot.
        elemVal = builder->CreatePtrToInt(elemPtr, llvm::Type::getInt32Ty(*context), "elemInt");
        elemClassName = "";
    }

    // The loop variable reuses an existing variable of the same name, or gets
    // a fresh slot in the entry block; it is overwritten on every iteration.
    llvm::AllocaInst* varAlloca = nullptr;
    auto existing = namedValues.find(node.variableName);
    if (existing == namedValues.end()) {
        varAlloca = createEntryBlockAlloca(theFunction, node.variableName, elemVal->getType());
        namedValues[node.variableName] = varAlloca;
    } else {
        varAlloca = existing->second;
    }
    builder->CreateStore(elemVal, varAlloca);

    // Keep type tracking in sync with the element type, dropping any marker
    // left over from an earlier variable of the same name.
    if (!elemClassName.empty()) {
        varTypes[node.variableName] = elemClassName;
    } else {
        varTypes.erase(node.variableName);
    }

    // Execute body
    breakStack.push_back(afterBB);
    node.body->accept(*this);
    breakStack.pop_back();

    // Increment the index and loop back, unless the body already terminated
    // its block (e.g. with a return): nothing may follow a terminator.
    if (!builder->GetInsertBlock()->getTerminator()) {
        llvm::Value* nextIndex = builder->CreateAdd(indexVal, llvm::ConstantInt::get(*context, llvm::APInt(32, 1)), "nextindex");
        builder->CreateStore(nextIndex, indexAlloca);
        builder->CreateBr(condBB);
    }

    // After Block
    theFunction->insert(theFunction->end(), afterBB);
    builder->SetInsertPoint(afterBB);
}
/// Generates an expression used as a statement by visiting the wrapped
/// expression; whatever it leaves in `lastValue` is not used here.
void CodeGen::visit(ExpressionStmt& node) {
    auto& wrapped = *node.expression;
    wrapped.accept(*this);
}
/// Generates a user-defined function.
///
/// Every parameter and the return value are i32. Parameters are copied into
/// stack slots so they behave like ordinary local variables. The caller's
/// variable tables are saved before the body is generated and restored
/// afterwards, so locals of the function and of the surrounding code do not
/// see each other.
void CodeGen::visit(FunctionDef& node) {
    // Save the enclosing scope (which might be main's locals). The type
    // markers are saved too: otherwise a parameter or local that shares a
    // name with an outer variable would inherit that variable's marker.
    auto oldNamedValues = namedValues;
    auto oldVarTypes = varTypes;

    // 1. Define Function Type (Int32 -> Int32, Int32...)
    std::vector<llvm::Type*> ints(node.args.size(), llvm::Type::getInt32Ty(*context));
    llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getInt32Ty(*context), ints, false);

    // 2. Create Function
    llvm::Function* f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, node.name, module.get());

    // 3. Set Argument Names
    unsigned idx = 0;
    for (auto& arg : f->args()) {
        arg.setName(node.args[idx++]);
    }

    // 4. Create Entry Block
    llvm::BasicBlock* bb = llvm::BasicBlock::Create(*context, "entry", f);
    builder->SetInsertPoint(bb);

    // 5. Record Arguments in Symbol Table (Allocas)
    namedValues.clear();
    varTypes.clear();
    for (auto& arg : f->args()) {
        const std::string argName(arg.getName());
        llvm::AllocaInst* alloca = createEntryBlockAlloca(f, argName, arg.getType());
        builder->CreateStore(&arg, alloca);
        namedValues[argName] = alloca;
    }

    // 6. Generate Body
    node.body->accept(*this);

    // A body that can fall off its end leaves the last block without a
    // terminator; close it with a default `return 0`.
    if (!builder->GetInsertBlock()->getTerminator()) {
        builder->CreateRet(llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)));
    }

    // 7. Verify Function (verifyFunction returns true when the IR is broken)
    if (llvm::verifyFunction(*f, &llvm::errs())) {
        std::cerr << "Invalid IR generated for function: " << node.name << std::endl;
    }

    // Restore the enclosing scope
    namedValues = oldNamedValues;
    varTypes = oldVarTypes;
}
/// Registers a class: creates a named LLVM struct with one i32 member per
/// declared field and remembers the field names for property access.
void CodeGen::visit(ClassDef& node) {
    // All fields are i32 for now.
    llvm::Type* const fieldTy = llvm::Type::getInt32Ty(*context);
    const std::vector<llvm::Type*> layout(node.fields.size(), fieldTy);

    llvm::StructType* const classTy = llvm::StructType::create(*context, layout, node.name);
    classFields[node.name] = node.fields;
    classStructs[node.name] = classTy;
}
/// Generates `new ClassName`: reserves a struct of the class type in the
/// entry block of the current function and yields a pointer to it.
/// Reports an error and yields a null value for an unknown class.
void CodeGen::visit(NewExpr& node) {
    const auto known = classStructs.find(node.className);
    if (known == classStructs.end()) {
        std::cerr << "Unknown class: " << node.className << std::endl;
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    // Objects are stack allocated for now (an alloca, not a heap allocation).
    llvm::StructType* const layout = known->second;
    llvm::Function* const owner = builder->GetInsertBlock()->getParent();
    lastValue = createEntryBlockAlloca(owner, "new_" + node.className, layout);
    lastClassName = node.className;  // the value is a pointer to this class's struct
}
/// Generates a property read `object.name`: resolves the field index from
/// the object's tracked class and loads the i32 stored in that struct member.
/// Every failure reports an error and yields a null value.
void CodeGen::visit(GetPropExpr& node) {
    // Evaluate the object; its class comes from the tracked type marker.
    node.object->accept(*this);
    llvm::Value* const target = lastValue;
    const std::string ownerClass = lastClassName;
    if (target == nullptr) {
        return;
    }
    if (ownerClass.empty()) {
        std::cerr << "Cannot determine class type for property access." << std::endl;
        lastValue = nullptr;
        return;
    }

    const auto classEntry = classFields.find(ownerClass);
    if (classEntry == classFields.end()) {
        std::cerr << "Unknown class type: " << ownerClass << std::endl;
        lastValue = nullptr;
        return;
    }

    // Look the field up by name; its position is the struct member index.
    const auto& declared = classEntry->second;
    int fieldIndex = -1;
    size_t pos = 0;
    while (pos < declared.size() && fieldIndex == -1) {
        if (declared[pos] == node.name) {
            fieldIndex = pos;
        }
        ++pos;
    }
    if (fieldIndex == -1) {
        std::cerr << "Unknown field: " << node.name << " in class " << ownerClass << std::endl;
        lastValue = nullptr;
        return;
    }

    // GEP with {0, fieldIndex}: step through the pointer, then select the member.
    std::vector<llvm::Value*> path;
    path.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, 0)));
    path.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, fieldIndex)));
    llvm::Type* const layout = classStructs[ownerClass];
    llvm::Value* const memberAddr = builder->CreateGEP(layout, target, path, "fieldptr");

    lastValue = builder->CreateLoad(llvm::Type::getInt32Ty(*context), memberAddr, "fieldval");
    lastClassName = "";  // fields are plain integers
}
/// Generates a property write `object.name = value`: resolves the field
/// index from the object's tracked class and stores the value into that
/// struct member. The stored value is the value of the expression.
/// Every failure yields a null `lastValue`, matching property reads.
void CodeGen::visit(SetPropExpr& node) {
    // 1. Evaluate object
    node.object->accept(*this);
    llvm::Value* objectPtr = lastValue;
    std::string className = lastClassName;
    if (!objectPtr) return;

    // 2. Evaluate value
    node.value->accept(*this);
    llvm::Value* val = lastValue;
    if (!val) return;  // value failed to generate; never store a null Value

    if (className.empty()) {
        std::cerr << "Cannot determine class type for property set." << std::endl;
        lastValue = nullptr;
        return;
    }

    // find() instead of operator[] so an unknown class is reported rather
    // than silently inserted into the table as a class without fields.
    const auto classEntry = classFields.find(className);
    if (classEntry == classFields.end()) {
        std::cerr << "Unknown class type: " << className << std::endl;
        lastValue = nullptr;
        return;
    }

    // 3. Resolve the field index by name
    const auto& fields = classEntry->second;
    int fieldIndex = -1;
    for (size_t i = 0; i < fields.size(); ++i) {
        if (fields[i] == node.name) {
            fieldIndex = static_cast<int>(i);
            break;
        }
    }
    if (fieldIndex == -1) {
        std::cerr << "Unknown field: " << node.name << " in class " << className << std::endl;
        lastValue = nullptr;
        return;
    }

    // Every field is an i32 member; storing a wider value (e.g. a pointer)
    // would write past the member.
    if (!val->getType()->isIntegerTy(32)) {
        std::cerr << "Only integer values can be stored in field: " << node.name << std::endl;
        lastValue = nullptr;
        return;
    }

    // 4. GEP with {0, fieldIndex} and store
    std::vector<llvm::Value*> indices;
    indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, 0)));
    indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, fieldIndex)));
    llvm::Type* structType = classStructs[className];
    llvm::Value* fieldPtr = builder->CreateGEP(structType, objectPtr, indices, "fieldptr");
    builder->CreateStore(val, fieldPtr);
    lastValue = val;
}
/// Generates an array literal: creates a runtime array of the literal's
/// size and stores every element through `sun_array_set`.
///
/// Elements travel as i8*: pointers are bit-cast, integers are converted
/// with inttoptr. The result is tagged "StringArray" if any element is a
/// string, otherwise "IntArray". If an element fails to generate, the
/// expression yields a null value.
void CodeGen::visit(ArrayExpr& node) {
    const int size = static_cast<int>(node.elements.size());
    llvm::Function* createFn = module->getFunction("sun_array_create");
    llvm::Function* setFn = module->getFunction("sun_array_set");
    llvm::Type* bytePtrTy = llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0);

    llvm::Value* sizeVal = llvm::ConstantInt::get(*context, llvm::APInt(32, size));
    llvm::Value* arrPtr = builder->CreateCall(createFn, {sizeVal}, "array");

    std::string elemType = "IntArray";  // Default to IntArray
    for (int i = 0; i < size; ++i) {
        node.elements[i]->accept(*this);
        llvm::Value* val = lastValue;
        if (!val) {
            // The element already reported its error; there is no value to
            // inspect or store.
            lastValue = nullptr;
            lastClassName = "";
            return;
        }
        if (lastClassName == "String") {
            elemType = "StringArray";
        }

        // Cast value to i8* (void*)
        llvm::Value* voidVal = nullptr;
        if (val->getType()->isPointerTy()) {
            voidVal = builder->CreateBitCast(val, bytePtrTy);
        } else {
            // Assume int (i32): carry it in the pointer slot via inttoptr.
            voidVal = builder->CreateIntToPtr(val, bytePtrTy);
        }
        llvm::Value* indexVal = llvm::ConstantInt::get(*context, llvm::APInt(32, i));
        builder->CreateCall(setFn, {arrPtr, indexVal, voidVal});
    }
    lastValue = arrPtr;
    lastClassName = elemType;
}
/// Generates an element read `array[index]` through `sun_array_get`.
///
/// The result type follows the array's type marker: "StringArray" yields a
/// string, "CSVArray" yields a string array, anything else is treated as an
/// array of i32. If the array or the index fails to generate, the expression
/// yields a null value.
void CodeGen::visit(IndexExpr& node) {
    node.array->accept(*this);
    llvm::Value* arrPtr = lastValue;
    std::string arrayType = lastClassName;
    if (!arrPtr) {
        lastClassName = "";
        return;  // lastValue is already null
    }

    node.index->accept(*this);
    llvm::Value* indexVal = lastValue;
    if (!indexVal) {
        lastClassName = "";
        return;  // lastValue is already null
    }

    llvm::Function* getFn = module->getFunction("sun_array_get");
    llvm::Value* valPtr = builder->CreateCall(getFn, {arrPtr, indexVal}, "elem");
    if (arrayType == "StringArray") {
        lastValue = valPtr;  // Already i8*
        lastClassName = "String";
    } else if (arrayType == "CSVArray") {
        lastValue = valPtr;  // It's a void* (pointer to array)
        lastClassName = "StringArray";
    } else {
        // Assume IntArray, cast back to i32
        lastValue = builder->CreatePtrToInt(valPtr, llvm::Type::getInt32Ty(*context), "elemInt");
        lastClassName = "";
    }
}
/// Generates an element write `array[index] = value` through
/// `sun_array_set`. The stored value is the value of the expression.
/// If the array, the index or the value fails to generate, nothing is
/// emitted and the expression yields a null value.
void CodeGen::visit(ArrayAssignExpr& node) {
    node.array->accept(*this);
    llvm::Value* arrPtr = lastValue;
    if (!arrPtr) return;  // lastValue is already null

    node.index->accept(*this);
    llvm::Value* indexVal = lastValue;
    if (!indexVal) return;

    node.value->accept(*this);
    llvm::Value* val = lastValue;
    if (!val) return;

    llvm::Function* setFn = module->getFunction("sun_array_set");
    llvm::Type* bytePtrTy = llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0);

    // Elements travel as i8*: pointers are bit-cast, integers use inttoptr.
    llvm::Value* voidVal = nullptr;
    if (val->getType()->isPointerTy()) {
        voidVal = builder->CreateBitCast(val, bytePtrTy);
    } else {
        voidVal = builder->CreateIntToPtr(val, bytePtrTy);
    }
    builder->CreateCall(setFn, {arrPtr, indexVal, voidVal});
    lastValue = val;
}
/// Generates a switch statement as a chain of equality tests.
///
/// Each case value is compared with the condition in turn; the first match
/// jumps to that case's body. Bodies fall through to the next case (the last
/// one falls into the default block) unless they end in a terminator, and
/// `break` jumps to the merge block through `breakStack`. A case whose value
/// fails to generate can never match, but its body is still generated so
/// that fallthrough from the previous case stays well-formed.
void CodeGen::visit(SwitchStmt& node) {
    // 1. Evaluate condition
    node.condition->accept(*this);
    llvm::Value* condVal = lastValue;
    if (!condVal) return;

    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    const size_t caseCount = node.cases.size();

    // 2. Create Merge Block (after switch); 'break' jumps here
    llvm::BasicBlock* mergeBB = llvm::BasicBlock::Create(*context, "switchmerge");
    breakStack.push_back(mergeBB);

    // 3. Create Default Block
    llvm::BasicBlock* defaultBB = llvm::BasicBlock::Create(*context, "switchdefault");

    // 4. Create Blocks for Cases
    std::vector<llvm::BasicBlock*> caseBBs;
    caseBBs.reserve(caseCount);
    for (size_t i = 0; i < caseCount; ++i) {
        caseBBs.push_back(llvm::BasicBlock::Create(*context, "case" + std::to_string(i)));
    }

    // 5. Generate Comparisons (If-Else Chain)
    for (size_t i = 0; i < caseCount; ++i) {
        const bool isLastCase = (i + 1 == caseCount);

        // Evaluate case value
        node.cases[i].value->accept(*this);
        llvm::Value* caseVal = lastValue;

        // Compare (only when the case value was generated successfully)
        llvm::Value* cmp = nullptr;
        if (caseVal) {
            cmp = builder->CreateICmpEQ(condVal, caseVal, "casecmp");
        }

        // A failed test continues with the next test, or the default block.
        llvm::BasicBlock* nextTestBB = isLastCase
            ? defaultBB
            : llvm::BasicBlock::Create(*context, "nexttest", theFunction);

        if (cmp) {
            builder->CreateCondBr(cmp, caseBBs[i], nextTestBB);
        } else {
            // No usable case value: this case never matches.
            builder->CreateBr(nextTestBB);
        }
        if (!isLastCase) {
            builder->SetInsertPoint(nextTestBB);
        }
    }
    if (caseCount == 0) {
        builder->CreateBr(defaultBB);
    }

    // 6. Generate Case Bodies
    for (size_t i = 0; i < caseCount; ++i) {
        theFunction->insert(theFunction->end(), caseBBs[i]);
        builder->SetInsertPoint(caseBBs[i]);
        node.cases[i].body->accept(*this);
        // Fallthrough into the next case, or into default after the last one
        if (!builder->GetInsertBlock()->getTerminator()) {
            builder->CreateBr(i + 1 < caseCount ? caseBBs[i + 1] : defaultBB);
        }
    }

    // 7. Generate Default Body
    theFunction->insert(theFunction->end(), defaultBB);
    builder->SetInsertPoint(defaultBB);
    if (node.defaultCase) {
        node.defaultCase->accept(*this);
    }
    if (!builder->GetInsertBlock()->getTerminator()) {
        builder->CreateBr(mergeBB);
    }

    // 8. Finish
    breakStack.pop_back();
    theFunction->insert(theFunction->end(), mergeBB);
    builder->SetInsertPoint(mergeBB);
}
/// Generates `break`: branches to the innermost break target and then
/// redirects the builder to a fresh "dead" block, so statements following
/// the break are not appended after the branch.
void CodeGen::visit(BreakStmt& node) {
    if (breakStack.empty()) {
        std::cerr << "Error: break statement outside loop or switch" << std::endl;
        return;
    }

    llvm::BasicBlock* const current = builder->GetInsertBlock();
    builder->CreateBr(breakStack.back());

    // Passing the parent function appends the new block at the function's end.
    llvm::BasicBlock* const afterBreak = llvm::BasicBlock::Create(*context, "dead", current->getParent());
    builder->SetInsertPoint(afterBreak);
}