// LLVM IR generation for the Sun language: a visitor over the AST that emits
// a single module ("sun_module") with a synthesized `main` for top-level code.
#include "codegen.h"

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

#include <llvm/IR/Verifier.h>

namespace {

/// Pointer type used for every runtime-managed value (strings, arrays, CSV
/// data); the runtime functions declared in the constructor take/return it.
llvm::PointerType* bytePtrTy(llvm::LLVMContext& ctx) {
    return llvm::PointerType::get(llvm::Type::getInt8Ty(ctx), 0);
}

/// True when the builder's current block already ends in a terminator
/// (ret/br), i.e. no further instructions may legally be appended to it.
bool currentBlockTerminated(llvm::IRBuilderBase& b) {
    llvm::BasicBlock* bb = b.GetInsertBlock();
    return bb != nullptr && bb->getTerminator() != nullptr;
}

/// Lower a truth value to i1 by comparing it against zero / null.
/// Works for any integer width and for pointers; i1 values pass through.
llvm::Value* toCondition(llvm::IRBuilderBase& b, llvm::Value* v, const llvm::Twine& name) {
    if (v->getType()->isIntegerTy(1)) {
        return v;
    }
    return b.CreateICmpNE(v, llvm::Constant::getNullValue(v->getType()), name);
}

/// Box a value into the runtime's untyped (i8*) array slot: pointers are
/// bitcast, integers are converted with inttoptr.
llvm::Value* boxAsBytePtr(llvm::IRBuilderBase& b, llvm::Value* v) {
    llvm::PointerType* ptrTy = bytePtrTy(b.getContext());
    if (v->getType()->isPointerTy()) {
        return b.CreateBitCast(v, ptrTy);
    }
    return b.CreateIntToPtr(v, ptrTy);
}

/// Position of `name` in a class's field list, or -1 if it is not a field.
template <typename FieldList>
int findFieldIndex(const FieldList& fields, const std::string& name) {
    for (std::size_t i = 0; i < fields.size(); ++i) {
        if (fields[i] == name) {
            return static_cast<int>(i);
        }
    }
    return -1;
}

}  // namespace

/// Creates the LLVM context/module/builder and declares the external runtime
/// functions the generated code calls.
CodeGen::CodeGen() {
    context = std::make_unique<llvm::LLVMContext>();
    module = std::make_unique<llvm::Module>("sun_module", *context);
    builder = std::make_unique<llvm::IRBuilder<>>(*context);

    llvm::Type* voidTy = llvm::Type::getVoidTy(*context);
    llvm::Type* i32Ty = llvm::Type::getInt32Ty(*context);
    llvm::Type* i8PtrTy = bytePtrTy(*context);

    auto declareRuntime = [this](const char* name, llvm::Type* returnType,
                                 std::vector<llvm::Type*> params) {
        llvm::FunctionType* type = llvm::FunctionType::get(returnType, params, /*isVarArg=*/false);
        llvm::Function::Create(type, llvm::Function::ExternalLinkage, name, module.get());
    };

    declareRuntime("print_int", voidTy, {i32Ty});                                  // void print_int(int)
    declareRuntime("print_string", voidTy, {i8PtrTy});                             // void print_string(char*)
    declareRuntime("int_to_str", i8PtrTy, {i32Ty});                                // char* int_to_str(int)
    declareRuntime("str_concat", i8PtrTy, {i8PtrTy, i8PtrTy});                     // char* str_concat(char*, char*)
    declareRuntime("sun_array_create", i8PtrTy, {i32Ty});                          // void* sun_array_create(int size)
    declareRuntime("sun_array_set", voidTy, {i8PtrTy, i32Ty, i8PtrTy});            // void sun_array_set(void*, int, void*)
    declareRuntime("sun_array_get", i8PtrTy, {i8PtrTy, i32Ty});                    // void* sun_array_get(void*, int)
    declareRuntime("sun_array_length", i32Ty, {i8PtrTy});                          // int sun_array_length(void*)
    declareRuntime("sun_read_csv", i8PtrTy, {i8PtrTy, i8PtrTy, i8PtrTy});          // void* sun_read_csv(filename, separator, quote)
}

/// Creates an alloca of `type` at the start of `theFunction`'s entry block.
llvm::AllocaInst* CodeGen::createEntryBlockAlloca(llvm::Function* theFunction,
                                                  const std::string& varName,
                                                  llvm::Type* type) {
    llvm::IRBuilder<> tmpBuilder(&theFunction->getEntryBlock(), theFunction->getEntryBlock().begin());
    return tmpBuilder.CreateAlloca(type, nullptr, varName);
}

/// Emits one top-level AST node. Function and class definitions are emitted
/// on their own; every other node is appended to the synthesized `main`,
/// which is created on first use and resumed at `mainInsertBlock` afterwards.
void CodeGen::generate(ASTNode& node) {
    const bool isDefinition = dynamic_cast<FunctionDef*>(&node) != nullptr ||
                              dynamic_cast<ClassDef*>(&node) != nullptr;

    if (isDefinition) {
        // Remember where `main` left off, but only if the builder is actually
        // inside `main`: after a previous definition it may still point into
        // that other function.
        llvm::BasicBlock* current = builder->GetInsertBlock();
        if (mainFunction && current && current->getParent() == mainFunction) {
            mainInsertBlock = current;
        }
        node.accept(*this);
        return;
    }

    if (!mainFunction) {
        llvm::FunctionType* mainType = llvm::FunctionType::get(llvm::Type::getInt32Ty(*context), false);
        mainFunction = llvm::Function::Create(mainType, llvm::Function::ExternalLinkage, "main", module.get());
        builder->SetInsertPoint(llvm::BasicBlock::Create(*context, "entry", mainFunction));
    } else if (mainInsertBlock) {
        builder->SetInsertPoint(mainInsertBlock);
    }

    node.accept(*this);
    mainInsertBlock = builder->GetInsertBlock();
}

/// Terminates `main` with `ret i32 0` if it is still open, then verifies it
/// (diagnostics go to stderr).
void CodeGen::finish() {
    if (!mainFunction) {
        return;
    }
    if (mainInsertBlock && !mainInsertBlock->getTerminator()) {
        builder->SetInsertPoint(mainInsertBlock);
        builder->CreateRet(llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)));
    }
    if (llvm::verifyFunction(*mainFunction, &llvm::errs())) {
        std::cerr << "Invalid IR generated for function: main" << '\n';
    }
}

/// Prints the textual IR of the whole module to stdout.
void CodeGen::print() {
    module->print(llvm::outs(), nullptr);
}

/// Integer literal -> i32 constant.
void CodeGen::visit(NumberExpr& node) {
    lastValue = llvm::ConstantInt::get(*context, llvm::APInt(32, node.value, true));
    lastClassName = "";  // Not a class
}

/// String literal -> pointer to a private global; typed as "String".
void CodeGen::visit(StringExpr& node) {
    lastValue = builder->CreateGlobalString(node.value);
    lastClassName = "String";  // Special marker for string
}

/// Loads a variable's current value from its stack slot and restores its
/// tracked class name ("" when untracked / int).
void CodeGen::visit(VariableExpr& node) {
    auto slot = namedValues.find(node.name);
    if (slot == namedValues.end()) {
        std::cerr << "Unknown variable name: " << node.name << '\n';
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    llvm::AllocaInst* alloca = slot->second;
    lastValue = builder->CreateLoad(alloca->getAllocatedType(), alloca, node.name);

    auto tracked = varTypes.find(node.name);
    if (tracked != varTypes.end()) {
        lastClassName = tracked->second;
    } else {
        lastClassName = "";
    }
}

/// Stores into an existing variable; the assignment evaluates to the stored
/// value. A non-empty class name of the value replaces the tracked type.
void CodeGen::visit(AssignExpr& node) {
    node.value->accept(*this);
    llvm::Value* val = lastValue;
    const std::string valClassName = lastClassName;
    if (!val) {
        return;
    }

    auto slot = namedValues.find(node.name);
    if (slot == namedValues.end()) {
        std::cerr << "Unknown variable name: " << node.name << '\n';
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    builder->CreateStore(val, slot->second);
    lastValue = val;
    lastClassName = valClassName;
    if (!valClassName.empty()) {
        varTypes[node.name] = valClassName;
    }
}

/// Calls: the builtins `readcsv`, `len` and `print` are lowered to runtime
/// calls; anything else must be a function already present in the module.
void CodeGen::visit(CallExpr& node) {
    // Evaluates argument `i`; returns nullptr if it produced no value.
    auto evalArg = [&](std::size_t i) -> llvm::Value* {
        node.args[i]->accept(*this);
        return lastValue;
    };
    auto fail = [&]() {
        lastValue = nullptr;
        lastClassName = "";
    };

    if (node.callee == "readcsv") {
        if (node.args.empty() || node.args.size() > 3) {
            std::cerr << "readcsv() takes 1 to 3 arguments." << '\n';
            fail();
            return;
        }
        llvm::Value* filenameVal = evalArg(0);
        if (!filenameVal) { fail(); return; }

        // Defaults: comma separator, double-quote quoting.
        llvm::Value* sepVal = nullptr;
        if (node.args.size() >= 2) {
            sepVal = evalArg(1);
        } else {
            sepVal = builder->CreateGlobalString(",");
        }
        if (!sepVal) { fail(); return; }

        llvm::Value* quoteVal = nullptr;
        if (node.args.size() >= 3) {
            quoteVal = evalArg(2);
        } else {
            quoteVal = builder->CreateGlobalString("\"");
        }
        if (!quoteVal) { fail(); return; }

        llvm::Function* readCsvFn = module->getFunction("sun_read_csv");
        lastValue = builder->CreateCall(readCsvFn, {filenameVal, sepVal, quoteVal}, "csv_data");
        lastClassName = "CSVArray";
        return;
    }

    if (node.callee == "len") {
        if (node.args.size() != 1) {
            std::cerr << "len() takes exactly 1 argument." << '\n';
            fail();
            return;
        }
        llvm::Value* arrVal = evalArg(0);
        if (!arrVal) { fail(); return; }
        // Only runtime arrays (pointers) have a length; an int->ptr bitcast
        // would be invalid IR.
        if (!arrVal->getType()->isPointerTy()) {
            std::cerr << "len() expects an array argument." << '\n';
            fail();
            return;
        }
        arrVal = builder->CreateBitCast(arrVal, bytePtrTy(*context));

        llvm::Function* lenFn = module->getFunction("sun_array_length");
        lastValue = builder->CreateCall(lenFn, {arrVal}, "len");
        lastClassName = "";  // Returns int
        return;
    }

    if (node.callee == "print") {
        if (node.args.size() != 1) {
            std::cerr << "print() takes exactly 1 argument." << '\n';
            fail();
            return;
        }
        llvm::Value* argVal = evalArg(0);
        const std::string argType = lastClassName;
        if (!argVal) { fail(); return; }

        if (argType == "String") {
            builder->CreateCall(module->getFunction("print_string"), {argVal});
        } else if (argVal->getType()->isIntegerTy(32)) {
            builder->CreateCall(module->getFunction("print_int"), {argVal});
        } else {
            std::cerr << "print() supports only integers and strings." << '\n';
        }
        fail();  // print returns void
        return;
    }

    llvm::Function* calleeF = module->getFunction(node.callee);
    if (!calleeF) {
        std::cerr << "Unknown function referenced: " << node.callee << '\n';
        fail();
        return;
    }
    if (calleeF->arg_size() != node.args.size()) {
        std::cerr << "Incorrect # arguments passed." << '\n';
        fail();
        return;
    }

    std::vector<llvm::Value*> argsV;
    argsV.reserve(node.args.size());
    for (std::size_t i = 0; i < node.args.size(); ++i) {
        llvm::Value* arg = evalArg(i);
        if (!arg) { fail(); return; }
        // A mismatched argument type would make CreateCall build invalid IR.
        if (arg->getType() != calleeF->getFunctionType()->getParamType(static_cast<unsigned>(i))) {
            std::cerr << "Argument type mismatch in call to " << node.callee << '\n';
            fail();
            return;
        }
        argsV.push_back(arg);
    }

    lastValue = builder->CreateCall(calleeF, argsV, "calltmp");
    // No return-type information is tracked for user functions; assume int.
    lastClassName = "";
}

/// Binary operators. `+` with a String operand is string concatenation (ints
/// are converted via int_to_str); all other operators require i32 operands.
/// Comparisons yield i32 0 or 1.
void CodeGen::visit(BinaryExpr& node) {
    node.left->accept(*this);
    llvm::Value* lhs = lastValue;
    const std::string lhsClass = lastClassName;

    node.right->accept(*this);
    llvm::Value* rhs = lastValue;
    const std::string rhsClass = lastClassName;

    if (!lhs || !rhs) {
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    const bool lhsIsString = lhsClass == "String";
    const bool rhsIsString = rhsClass == "String";
    if (node.op == "+" && (lhsIsString || rhsIsString)) {
        llvm::Function* intToStr = module->getFunction("int_to_str");
        // Strings pass through; i32 values are stringified; anything else fails.
        auto asString = [&](llvm::Value* v, bool isString, const char* name) -> llvm::Value* {
            if (isString) return v;
            if (!v->getType()->isIntegerTy(32)) return nullptr;
            return builder->CreateCall(intToStr, {v}, name);
        };
        llvm::Value* strL = asString(lhs, lhsIsString, "l_str");
        llvm::Value* strR = asString(rhs, rhsIsString, "r_str");
        if (!strL || !strR) {
            std::cerr << "String concatenation supports only strings and integers." << '\n';
            lastValue = nullptr;
            lastClassName = "";
            return;
        }
        lastValue = builder->CreateCall(module->getFunction("str_concat"), {strL, strR}, "concat");
        lastClassName = "String";
        return;
    }

    // Pointer arithmetic and struct comparison are not supported.
    if (!lhs->getType()->isIntegerTy() || !rhs->getType()->isIntegerTy()) {
        std::cerr << "Binary operations only supported for integers." << '\n';
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    // Comparisons produce i1; widen with ZERO extension so true == 1
    // (a signed cast would turn true into -1).
    llvm::Type* i32Ty = llvm::Type::getInt32Ty(*context);
    auto asInt = [&](llvm::Value* bit) { return builder->CreateZExt(bit, i32Ty, "booltmp"); };

    lastClassName = "";
    const std::string& op = node.op;
    if (op == "+") {
        lastValue = builder->CreateAdd(lhs, rhs, "addtmp");
    } else if (op == "-") {
        lastValue = builder->CreateSub(lhs, rhs, "subtmp");
    } else if (op == "*") {
        lastValue = builder->CreateMul(lhs, rhs, "multmp");
    } else if (op == "/") {
        lastValue = builder->CreateSDiv(lhs, rhs, "divtmp");
    } else if (op == "%") {
        lastValue = builder->CreateSRem(lhs, rhs, "modtmp");
    } else if (op == "<") {
        lastValue = asInt(builder->CreateICmpSLT(lhs, rhs, "cmptmp"));
    } else if (op == "<=") {
        lastValue = asInt(builder->CreateICmpSLE(lhs, rhs, "cmptmp"));
    } else if (op == ">") {
        lastValue = asInt(builder->CreateICmpSGT(lhs, rhs, "cmptmp"));
    } else if (op == ">=") {
        lastValue = asInt(builder->CreateICmpSGE(lhs, rhs, "cmptmp"));
    } else if (op == "==") {
        lastValue = asInt(builder->CreateICmpEQ(lhs, rhs, "cmptmp"));
    } else if (op == "!=") {
        lastValue = asInt(builder->CreateICmpNE(lhs, rhs, "cmptmp"));
    } else {
        std::cerr << "Unknown operator: " << node.op << '\n';
        lastValue = nullptr;
    }
}

/// `return [expr];`. A bare `return` yields zero of the function's return
/// type (all generated functions return i32). Code after the return goes into
/// a fresh, unreachable block so nothing is appended after the terminator.
void CodeGen::visit(ReturnStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    if (node.value) {
        node.value->accept(*this);
        if (!lastValue) {
            return;
        }
        builder->CreateRet(lastValue);
    } else {
        llvm::Type* retTy = theFunction->getReturnType();
        if (retTy->isVoidTy()) {
            builder->CreateRetVoid();
        } else {
            builder->CreateRet(llvm::Constant::getNullValue(retTy));
        }
    }
    builder->SetInsertPoint(llvm::BasicBlock::Create(*context, "afterret", theFunction));
}

/// `var name [= init];` -> entry-block alloca typed after the initializer
/// (i32 0 when absent). A re-declaration replaces the slot and its type.
void CodeGen::visit(VarDeclStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();

    llvm::Value* initVal = nullptr;
    std::string initClassName;
    if (node.initializer) {
        node.initializer->accept(*this);
        initVal = lastValue;
        initClassName = lastClassName;
        if (!initVal) {
            return;
        }
    } else {
        initVal = llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true));
    }

    llvm::AllocaInst* alloca = createEntryBlockAlloca(theFunction, node.name, initVal->getType());
    builder->CreateStore(initVal, alloca);
    namedValues[node.name] = alloca;

    // Drop a stale class name left over from an earlier declaration.
    if (initClassName.empty()) {
        varTypes.erase(node.name);
    } else {
        varTypes[node.name] = initClassName;
    }
}

/// `{ ... }`: emits each statement in order (no new variable scope).
void CodeGen::visit(BlockStmt& node) {
    for (const auto& stmt : node.statements) {
        stmt->accept(*this);
    }
}

/// if/else lowered to then/else/ifcont blocks. A branch that already ends in a
/// terminator does not get a branch to the merge block.
void CodeGen::visit(IfStmt& node) {
    node.condition->accept(*this);
    if (!lastValue) {
        return;
    }
    llvm::Value* condV = toCondition(*builder, lastValue, "ifcond");

    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::BasicBlock* thenBB = llvm::BasicBlock::Create(*context, "then", theFunction);
    llvm::BasicBlock* elseBB = llvm::BasicBlock::Create(*context, "else");
    llvm::BasicBlock* mergeBB = llvm::BasicBlock::Create(*context, "ifcont");
    builder->CreateCondBr(condV, thenBB, elseBB);

    builder->SetInsertPoint(thenBB);
    node.thenBranch->accept(*this);
    if (!currentBlockTerminated(*builder)) {
        builder->CreateBr(mergeBB);
    }

    theFunction->insert(theFunction->end(), elseBB);
    builder->SetInsertPoint(elseBB);
    if (node.elseBranch) {
        node.elseBranch->accept(*this);
    }
    if (!currentBlockTerminated(*builder)) {
        builder->CreateBr(mergeBB);
    }

    theFunction->insert(theFunction->end(), mergeBB);
    builder->SetInsertPoint(mergeBB);
}

/// while-loop: loopcond -> loopbody -> loopcond, exit to loopafter.
/// `break` inside the body targets loopafter.
void CodeGen::visit(WhileStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction);
    builder->CreateBr(condBB);

    builder->SetInsertPoint(condBB);
    node.condition->accept(*this);
    if (!lastValue) {
        return;
    }
    llvm::Value* condV = toCondition(*builder, lastValue, "loopcond");

    llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody");
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter");
    builder->CreateCondBr(condV, bodyBB, afterBB);

    theFunction->insert(theFunction->end(), bodyBB);
    builder->SetInsertPoint(bodyBB);
    breakStack.push_back(afterBB);
    node.body->accept(*this);
    breakStack.pop_back();
    if (!currentBlockTerminated(*builder)) {
        builder->CreateBr(condBB);  // Loop back to condition
    }

    theFunction->insert(theFunction->end(), afterBB);
    builder->SetInsertPoint(afterBB);
}

/// C-style for-loop: init; loopcond (missing condition = always true);
/// loopbody + increment; exit to loopafter.
void CodeGen::visit(ForStmt& node) {
    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    if (node.init) {
        node.init->accept(*this);
    }

    llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction);
    builder->CreateBr(condBB);
    builder->SetInsertPoint(condBB);

    llvm::Value* condV = nullptr;
    if (node.condition) {
        node.condition->accept(*this);
        if (!lastValue) {
            return;
        }
        condV = toCondition(*builder, lastValue, "loopcond");
    } else {
        condV = builder->getTrue();  // Infinite loop if no condition
    }

    llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody");
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter");
    builder->CreateCondBr(condV, bodyBB, afterBB);

    theFunction->insert(theFunction->end(), bodyBB);
    builder->SetInsertPoint(bodyBB);
    breakStack.push_back(afterBB);
    node.body->accept(*this);
    breakStack.pop_back();

    // The increment and back-edge are only legal if the body's last block is open.
    if (!currentBlockTerminated(*builder)) {
        if (node.increment) {
            node.increment->accept(*this);
        }
        builder->CreateBr(condBB);
    }

    theFunction->insert(theFunction->end(), afterBB);
    builder->SetInsertPoint(afterBB);
}

/// `for x in collection`: iterates a runtime array by index. The element type
/// follows the collection's tracked class: StringArray -> String,
/// CSVArray -> StringArray (one row), anything else -> int.
void CodeGen::visit(ForInStmt& node) {
    node.collection->accept(*this);
    llvm::Value* arrPtr = lastValue;
    const std::string arrayType = lastClassName;
    if (!arrPtr) {
        return;
    }
    if (!arrPtr->getType()->isPointerTy()) {
        std::cerr << "for-in requires an array collection." << '\n';
        return;
    }

    llvm::Type* i32Ty = llvm::Type::getInt32Ty(*context);
    llvm::Value* lenVal = builder->CreateCall(module->getFunction("sun_array_length"), {arrPtr}, "len");

    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::AllocaInst* indexAlloca = createEntryBlockAlloca(theFunction, "index_" + node.variableName, i32Ty);
    builder->CreateStore(llvm::ConstantInt::get(i32Ty, 0), indexAlloca);

    llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction);
    llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody");
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter");
    builder->CreateBr(condBB);

    // Condition: index < length
    builder->SetInsertPoint(condBB);
    llvm::Value* indexVal = builder->CreateLoad(i32Ty, indexAlloca, "index");
    builder->CreateCondBr(builder->CreateICmpSLT(indexVal, lenVal, "loopcond"), bodyBB, afterBB);

    // Body: fetch the element and bind it to the loop variable.
    theFunction->insert(theFunction->end(), bodyBB);
    builder->SetInsertPoint(bodyBB);
    llvm::Value* elemPtr = builder->CreateCall(module->getFunction("sun_array_get"), {arrPtr, indexVal}, "elem");

    llvm::Value* elemVal = elemPtr;
    std::string elemClassName;
    if (arrayType == "StringArray") {
        elemClassName = "String";  // Already i8*
    } else if (arrayType == "CSVArray") {
        elemClassName = "StringArray";  // Each CSV element is a row array (as in IndexExpr)
    } else {
        elemVal = builder->CreatePtrToInt(elemPtr, i32Ty, "elemInt");  // Assume IntArray
    }

    // Reuse the variable's slot only if it has the element's type; otherwise
    // the store below would be ill-typed.
    llvm::AllocaInst* varAlloca = nullptr;
    auto existing = namedValues.find(node.variableName);
    if (existing != namedValues.end() && existing->second->getAllocatedType() == elemVal->getType()) {
        varAlloca = existing->second;
    } else {
        varAlloca = createEntryBlockAlloca(theFunction, node.variableName, elemVal->getType());
        namedValues[node.variableName] = varAlloca;
    }
    builder->CreateStore(elemVal, varAlloca);
    if (elemClassName.empty()) {
        varTypes.erase(node.variableName);
    } else {
        varTypes[node.variableName] = elemClassName;
    }

    breakStack.push_back(afterBB);
    node.body->accept(*this);
    breakStack.pop_back();

    if (!currentBlockTerminated(*builder)) {
        llvm::Value* nextIndex = builder->CreateAdd(indexVal, llvm::ConstantInt::get(i32Ty, 1), "nextindex");
        builder->CreateStore(nextIndex, indexAlloca);
        builder->CreateBr(condBB);
    }

    theFunction->insert(theFunction->end(), afterBB);
    builder->SetInsertPoint(afterBB);
}

/// Expression used as a statement (e.g. a `print(...)` call): evaluated for
/// its side effects; the resulting value is discarded.
void CodeGen::visit(ExpressionStmt& node) {
    node.expression->accept(*this);
}

/// User function definition. Signature is i32(i32, ...). Parameters are
/// spilled to allocas. The caller's symbol tables and insertion point are
/// saved and restored. A body that can fall off the end returns 0.
void CodeGen::visit(FunctionDef& node) {
    auto savedNamedValues = namedValues;
    auto savedVarTypes = varTypes;
    llvm::BasicBlock* savedBlock = builder->GetInsertBlock();

    llvm::Type* i32Ty = llvm::Type::getInt32Ty(*context);
    std::vector<llvm::Type*> paramTypes(node.args.size(), i32Ty);
    llvm::FunctionType* ft = llvm::FunctionType::get(i32Ty, paramTypes, false);
    llvm::Function* f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, node.name, module.get());

    unsigned idx = 0;
    for (auto& arg : f->args()) {
        arg.setName(node.args[idx++]);
    }

    builder->SetInsertPoint(llvm::BasicBlock::Create(*context, "entry", f));

    // Fresh scope: neither variables nor their class names leak in from the caller.
    namedValues.clear();
    varTypes.clear();
    for (auto& arg : f->args()) {
        const std::string argName(arg.getName());
        llvm::AllocaInst* alloca = createEntryBlockAlloca(f, argName, arg.getType());
        builder->CreateStore(&arg, alloca);
        namedValues[argName] = alloca;
    }

    node.body->accept(*this);

    if (!currentBlockTerminated(*builder)) {
        builder->CreateRet(llvm::ConstantInt::get(i32Ty, 0));
    }

    if (llvm::verifyFunction(*f, &llvm::errs())) {
        std::cerr << "Invalid IR generated for function: " << node.name << '\n';
    }

    namedValues = std::move(savedNamedValues);
    varTypes = std::move(savedVarTypes);
    if (savedBlock) {
        builder->SetInsertPoint(savedBlock);
    } else {
        builder->ClearInsertionPoint();
    }
}

/// Class definition -> named LLVM struct with one i32 per field (all fields
/// are i32 for now); field order is recorded for property access.
void CodeGen::visit(ClassDef& node) {
    std::vector<llvm::Type*> fieldTypes(node.fields.size(), llvm::Type::getInt32Ty(*context));
    llvm::StructType* structType = llvm::StructType::create(*context, fieldTypes, node.name);
    classStructs[node.name] = structType;
    classFields[node.name] = node.fields;
}

/// `new C` -> entry-block alloca of the class struct (uninitialized).
/// NOTE(review): stack allocation means an object must not outlive the
/// function that created it (e.g. returning it would dangle).
void CodeGen::visit(NewExpr& node) {
    auto structIt = classStructs.find(node.className);
    if (structIt == classStructs.end()) {
        std::cerr << "Unknown class: " << node.className << '\n';
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    lastValue = createEntryBlockAlloca(theFunction, "new_" + node.className, structIt->second);
    lastClassName = node.className;
}

/// `obj.field` -> GEP into the object's struct + i32 load.
void CodeGen::visit(GetPropExpr& node) {
    node.object->accept(*this);
    llvm::Value* objectPtr = lastValue;
    const std::string className = lastClassName;
    if (!objectPtr) {
        return;
    }
    if (className.empty()) {
        std::cerr << "Cannot determine class type for property access." << '\n';
        lastValue = nullptr;
        return;
    }

    auto fieldsIt = classFields.find(className);
    auto structIt = classStructs.find(className);
    if (fieldsIt == classFields.end() || structIt == classStructs.end()) {
        std::cerr << "Unknown class type: " << className << '\n';
        lastValue = nullptr;
        return;
    }

    const int fieldIndex = findFieldIndex(fieldsIt->second, node.name);
    if (fieldIndex == -1) {
        std::cerr << "Unknown field: " << node.name << " in class " << className << '\n';
        lastValue = nullptr;
        return;
    }

    llvm::Type* i32Ty = llvm::Type::getInt32Ty(*context);
    llvm::Value* fieldPtr = builder->CreateStructGEP(structIt->second, objectPtr,
                                                     static_cast<unsigned>(fieldIndex), "fieldptr");
    lastValue = builder->CreateLoad(i32Ty, fieldPtr, "fieldval");
    lastClassName = "";  // Field is int
}

/// `obj.field = value` -> GEP into the object's struct + store; evaluates to
/// the stored value.
void CodeGen::visit(SetPropExpr& node) {
    node.object->accept(*this);
    llvm::Value* objectPtr = lastValue;
    const std::string className = lastClassName;
    if (!objectPtr) {
        return;
    }

    node.value->accept(*this);
    llvm::Value* val = lastValue;
    if (!val) {
        return;
    }

    if (className.empty()) {
        std::cerr << "Cannot determine class type for property set." << '\n';
        return;
    }

    // Use find (not operator[]) so a non-class type does not insert bogus entries.
    auto fieldsIt = classFields.find(className);
    auto structIt = classStructs.find(className);
    if (fieldsIt == classFields.end() || structIt == classStructs.end()) {
        std::cerr << "Unknown class type: " << className << '\n';
        lastValue = nullptr;
        return;
    }

    const int fieldIndex = findFieldIndex(fieldsIt->second, node.name);
    if (fieldIndex == -1) {
        std::cerr << "Unknown field: " << node.name << " in class " << className << '\n';
        return;
    }

    llvm::Value* fieldPtr = builder->CreateStructGEP(structIt->second, objectPtr,
                                                     static_cast<unsigned>(fieldIndex), "fieldptr");
    builder->CreateStore(val, fieldPtr);
    lastValue = val;
}

/// Array literal -> sun_array_create + one sun_array_set per element.
/// Typed "StringArray" if any element is a String, else "IntArray".
void CodeGen::visit(ArrayExpr& node) {
    llvm::Type* i32Ty = llvm::Type::getInt32Ty(*context);
    const std::size_t size = node.elements.size();

    llvm::Value* sizeVal = llvm::ConstantInt::get(i32Ty, size);
    llvm::Value* arrPtr = builder->CreateCall(module->getFunction("sun_array_create"), {sizeVal}, "array");
    llvm::Function* setFn = module->getFunction("sun_array_set");

    std::string elemType = "IntArray";  // Default to IntArray
    for (std::size_t i = 0; i < size; ++i) {
        node.elements[i]->accept(*this);
        llvm::Value* val = lastValue;
        if (!val) {
            lastValue = nullptr;
            lastClassName = "";
            return;
        }
        if (lastClassName == "String") {
            elemType = "StringArray";
        }
        llvm::Value* indexVal = llvm::ConstantInt::get(i32Ty, i);
        llvm::Value* boxed = boxAsBytePtr(*builder, val);
        builder->CreateCall(setFn, {arrPtr, indexVal, boxed});
    }

    lastValue = arrPtr;
    lastClassName = elemType;
}

/// `arr[i]` -> sun_array_get; the result is unboxed according to the array's
/// tracked class (see ForInStmt for the same mapping).
void CodeGen::visit(IndexExpr& node) {
    node.array->accept(*this);
    llvm::Value* arrPtr = lastValue;
    const std::string arrayType = lastClassName;

    node.index->accept(*this);
    llvm::Value* indexVal = lastValue;

    if (!arrPtr || !indexVal) {
        lastValue = nullptr;
        lastClassName = "";
        return;
    }

    llvm::Value* valPtr = builder->CreateCall(module->getFunction("sun_array_get"), {arrPtr, indexVal}, "elem");
    if (arrayType == "StringArray") {
        lastValue = valPtr;  // Already i8*
        lastClassName = "String";
    } else if (arrayType == "CSVArray") {
        lastValue = valPtr;  // Pointer to a row array
        lastClassName = "StringArray";
    } else {
        // Assume IntArray, cast back to i32
        lastValue = builder->CreatePtrToInt(valPtr, llvm::Type::getInt32Ty(*context), "elemInt");
        lastClassName = "";
    }
}

/// `arr[i] = value` -> sun_array_set with the value boxed to i8*; evaluates
/// to the stored value.
void CodeGen::visit(ArrayAssignExpr& node) {
    node.array->accept(*this);
    llvm::Value* arrPtr = lastValue;
    node.index->accept(*this);
    llvm::Value* indexVal = lastValue;
    node.value->accept(*this);
    llvm::Value* val = lastValue;

    if (!arrPtr || !indexVal || !val) {
        lastValue = nullptr;
        return;
    }

    llvm::Value* boxed = boxAsBytePtr(*builder, val);
    builder->CreateCall(module->getFunction("sun_array_set"), {arrPtr, indexVal, boxed});
    lastValue = val;
}

/// switch lowered to an equality test chain. Case bodies fall through to the
/// next case (and from the last case into default) unless they `break`,
/// which jumps to switchmerge.
void CodeGen::visit(SwitchStmt& node) {
    node.condition->accept(*this);
    llvm::Value* condVal = lastValue;
    if (!condVal) {
        return;
    }

    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::BasicBlock* mergeBB = llvm::BasicBlock::Create(*context, "switchmerge");
    llvm::BasicBlock* defaultBB = llvm::BasicBlock::Create(*context, "switchdefault");
    breakStack.push_back(mergeBB);

    const std::size_t caseCount = node.cases.size();
    std::vector<llvm::BasicBlock*> caseBBs;
    caseBBs.reserve(caseCount);
    for (std::size_t i = 0; i < caseCount; ++i) {
        caseBBs.push_back(llvm::BasicBlock::Create(*context, "case" + std::to_string(i)));
    }

    // Comparison chain: test i jumps to case i or to the next test
    // (the last test falls back to default).
    for (std::size_t i = 0; i < caseCount; ++i) {
        const bool isLast = i + 1 == caseCount;
        node.cases[i].value->accept(*this);
        llvm::Value* caseVal = lastValue;

        llvm::BasicBlock* nextTestBB =
            isLast ? defaultBB : llvm::BasicBlock::Create(*context, "nexttest", theFunction);
        if (caseVal) {
            llvm::Value* cmp = builder->CreateICmpEQ(condVal, caseVal, "casecmp");
            builder->CreateCondBr(cmp, caseBBs[i], nextTestBB);
        } else {
            // No value for this case label: treat it as never matching.
            builder->CreateBr(nextTestBB);
        }
        if (!isLast) {
            builder->SetInsertPoint(nextTestBB);
        }
    }
    if (caseCount == 0) {
        builder->CreateBr(defaultBB);
    }

    for (std::size_t i = 0; i < caseCount; ++i) {
        theFunction->insert(theFunction->end(), caseBBs[i]);
        builder->SetInsertPoint(caseBBs[i]);
        node.cases[i].body->accept(*this);
        if (!currentBlockTerminated(*builder)) {
            builder->CreateBr(i + 1 < caseCount ? caseBBs[i + 1] : defaultBB);  // Fallthrough
        }
    }

    theFunction->insert(theFunction->end(), defaultBB);
    builder->SetInsertPoint(defaultBB);
    if (node.defaultCase) {
        node.defaultCase->accept(*this);
    }
    if (!currentBlockTerminated(*builder)) {
        builder->CreateBr(mergeBB);
    }

    breakStack.pop_back();
    theFunction->insert(theFunction->end(), mergeBB);
    builder->SetInsertPoint(mergeBB);
}

/// `break` -> branch to the innermost loop/switch exit. Subsequent code goes
/// into a fresh unreachable "dead" block so nothing follows the terminator.
void CodeGen::visit(BreakStmt& node) {
    if (breakStack.empty()) {
        std::cerr << "Error: break statement outside loop or switch" << '\n';
        return;
    }
    builder->CreateBr(breakStack.back());

    llvm::Function* theFunction = builder->GetInsertBlock()->getParent();
    llvm::BasicBlock* deadBB = llvm::BasicBlock::Create(*context, "dead");
    theFunction->insert(theFunction->end(), deadBB);
    builder->SetInsertPoint(deadBB);
}