diff --git a/build/ast.o b/build/ast.o index c17c079..64be2af 100644 Binary files a/build/ast.o and b/build/ast.o differ diff --git a/build/ast_printer.o b/build/ast_printer.o index 4181c35..d464cb0 100644 Binary files a/build/ast_printer.o and b/build/ast_printer.o differ diff --git a/build/codegen.o b/build/codegen.o index 1f3e5a5..324067e 100644 Binary files a/build/codegen.o and b/build/codegen.o differ diff --git a/build/lexer.o b/build/lexer.o index 0569422..3d34873 100644 Binary files a/build/lexer.o and b/build/lexer.o differ diff --git a/build/parser.o b/build/parser.o index c4bde76..5ee9ebc 100644 Binary files a/build/parser.o and b/build/parser.o differ diff --git a/src/ast.cpp b/src/ast.cpp index 2d6cfaf..79944f4 100644 --- a/src/ast.cpp +++ b/src/ast.cpp @@ -11,6 +11,7 @@ void SetPropExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } void ArrayExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } void IndexExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } void ArrayAssignExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void IncrementExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } void BinaryExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } void ReturnStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } void VarDeclStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } diff --git a/src/ast.h b/src/ast.h index be613dd..68f61c6 100644 --- a/src/ast.h +++ b/src/ast.h @@ -108,6 +108,13 @@ public: void accept(ASTVisitor& visitor) override; }; +class IncrementExpr : public Expr { +public: + std::unique_ptr variable; // Can be VariableExpr, GetPropExpr, IndexExpr + IncrementExpr(std::unique_ptr variable) : variable(std::move(variable)) {} + void accept(ASTVisitor& visitor) override; +}; + class BinaryExpr : public Expr { public: std::unique_ptr left; @@ -239,6 +246,7 @@ public: virtual void visit(ArrayExpr& node) = 0; virtual void visit(IndexExpr& node) = 0; virtual void visit(ArrayAssignExpr& node) = 0; + virtual void visit(IncrementExpr& node) = 0; virtual void visit(BinaryExpr& node) = 0; virtual void visit(ReturnStmt& node) = 0; virtual void visit(VarDeclStmt& node) = 0; diff --git a/src/ast_printer.cpp b/src/ast_printer.cpp index ecf8842..18524a4 100644 --- a/src/ast_printer.cpp +++ b/src/ast_printer.cpp @@ -130,3 +130,60 @@ void ASTPrinter::visit(ClassDef& node) { } std::cout << "}"; } + +void ASTPrinter::visit(ArrayExpr& node) { + std::cout << "["; + for (size_t i = 0; i < node.elements.size(); ++i) { + node.elements[i]->accept(*this); + if (i < node.elements.size() - 1) std::cout << ", "; + } + std::cout << "]"; +} + +void ASTPrinter::visit(IndexExpr& node) { + node.array->accept(*this); + std::cout << "["; + node.index->accept(*this); + std::cout << "]"; +} + +void ASTPrinter::visit(ArrayAssignExpr& node) { + node.array->accept(*this); + std::cout << "["; + node.index->accept(*this); + std::cout << "] = "; + node.value->accept(*this); +} + +void ASTPrinter::visit(IncrementExpr& node) { + node.variable->accept(*this); + std::cout << "++"; +} + +void ASTPrinter::visit(ForInStmt& node) { + std::cout << "for (var " << node.variableName << " in "; + node.collection->accept(*this); + std::cout << ") "; + node.body->accept(*this); +} + +void ASTPrinter::visit(SwitchStmt& node) { + std::cout << "switch ("; + node.condition->accept(*this); + std::cout << ") {" << std::endl; + for (const auto& c : node.cases) { + std::cout << "case "; + c.value->accept(*this); + std::cout << ": "; + c.body->accept(*this); + } + if (node.defaultCase) { + std::cout << "default: "; + node.defaultCase->accept(*this); + } + std::cout << "}"; +} + +void ASTPrinter::visit(BreakStmt& node) { + std::cout << "break;"; +} diff --git a/src/ast_printer.h b/src/ast_printer.h index 90f861f..048697f 100644 --- a/src/ast_printer.h +++ b/src/ast_printer.h @@ -16,6 +16,10 @@ public: void visit(NewExpr& node) override; void visit(GetPropExpr& node) override; void visit(SetPropExpr& node) override; + void visit(ArrayExpr& node) override; + void visit(IndexExpr& node) override; + void visit(ArrayAssignExpr& node) override; + void visit(IncrementExpr& node) override; void visit(BinaryExpr& node) override; void visit(ReturnStmt& node) override; void visit(VarDeclStmt& node) override; @@ -23,6 +27,9 @@ public: void visit(IfStmt& node) override; void visit(WhileStmt& node) override; void visit(ForStmt& node) override; + void visit(ForInStmt& node) override; + void visit(SwitchStmt& node) override; + void visit(BreakStmt& node) override; void visit(ExpressionStmt& node) override; void visit(FunctionDef& node) override; void visit(ClassDef& node) override; diff --git a/src/codegen.cpp b/src/codegen.cpp index 2a85935..9b4726a 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -248,6 +248,87 @@ void CodeGen::visit(CallExpr& node) { lastClassName = ""; } +void CodeGen::visit(IncrementExpr& node) { + if (auto varExpr = dynamic_cast(node.variable.get())) { + if (namedValues.find(varExpr->name) == namedValues.end()) { + std::cerr << "Unknown variable: " << varExpr->name << std::endl; + lastValue = nullptr; + return; + } + llvm::AllocaInst* alloca = namedValues[varExpr->name]; + llvm::Value* val = builder->CreateLoad(alloca->getAllocatedType(), alloca, varExpr->name.c_str()); + lastValue = val; // Return old value + + llvm::Value* inc = builder->CreateAdd(val, llvm::ConstantInt::get(*context, llvm::APInt(32, 1)), "inc"); + builder->CreateStore(inc, alloca); + + } else if (auto getProp = dynamic_cast(node.variable.get())) { + getProp->object->accept(*this); + llvm::Value* objectPtr = lastValue; + std::string className = lastClassName; + + if (!objectPtr) return; + + if (classFields.find(className) == classFields.end()) { + std::cerr << "Unknown class type: " << className << std::endl; + lastValue = nullptr; + return; + } + + const auto& fields = classFields[className]; + int fieldIndex = -1; + for (size_t i = 0; i < fields.size(); ++i) { + if (fields[i] == getProp->name) { + fieldIndex = i; + break; + } + } + + if (fieldIndex == -1) { + std::cerr << "Unknown field: " << getProp->name << " in class " << className << std::endl; + lastValue = nullptr; + return; + } + + std::vector indices; + indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, 0))); + indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, fieldIndex))); + + llvm::Type* structType = classStructs[className]; + llvm::Value* fieldPtr = builder->CreateGEP(structType, objectPtr, indices, "fieldptr"); + + llvm::Value* val = builder->CreateLoad(llvm::Type::getInt32Ty(*context), fieldPtr, "fieldval"); + lastValue = val; + + llvm::Value* inc = builder->CreateAdd(val, llvm::ConstantInt::get(*context, llvm::APInt(32, 1)), "inc"); + builder->CreateStore(inc, fieldPtr); + + } else if (auto indexExpr = dynamic_cast(node.variable.get())) { + indexExpr->array->accept(*this); + llvm::Value* arrPtr = lastValue; + + indexExpr->index->accept(*this); + llvm::Value* indexVal = lastValue; + + llvm::Function* getFn = module->getFunction("sun_array_get"); + llvm::Function* setFn = module->getFunction("sun_array_set"); + + llvm::Value* valPtr = builder->CreateCall(getFn, {arrPtr, indexVal}, "elem"); + // Assume IntArray for increment + llvm::Value* val = builder->CreatePtrToInt(valPtr, llvm::Type::getInt32Ty(*context), "elemInt"); + lastValue = val; + + llvm::Value* inc = builder->CreateAdd(val, llvm::ConstantInt::get(*context, llvm::APInt(32, 1)), "inc"); + + // Store back + llvm::Value* voidVal = builder->CreateIntToPtr(inc, llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0)); + builder->CreateCall(setFn, {arrPtr, indexVal, voidVal}); + } else { + std::cerr << "Invalid increment target." << std::endl; + lastValue = nullptr; + } +} + void CodeGen::visit(BinaryExpr& node) { node.left->accept(*this); llvm::Value* L = lastValue; @@ -311,6 +392,12 @@ void CodeGen::visit(BinaryExpr& node) { } else if (node.op == ">") { lastValue = builder->CreateICmpSGT(L, R, "cmptmp"); lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp"); + } else if (node.op == "<=") { + lastValue = builder->CreateICmpSLE(L, R, "cmptmp"); + lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp"); + } else if (node.op == ">=") { + lastValue = builder->CreateICmpSGE(L, R, "cmptmp"); + lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp"); } else if (node.op == "==") { lastValue = builder->CreateICmpEQ(L, R, "cmptmp"); lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp"); diff --git a/src/codegen.h b/src/codegen.h index dbcfb82..8b0d284 100644 --- a/src/codegen.h +++ b/src/codegen.h @@ -28,6 +28,7 @@ public: void visit(ArrayExpr& node) override; void visit(IndexExpr& node) override; void visit(ArrayAssignExpr& node) override; + void visit(IncrementExpr& node) override; void visit(BinaryExpr& node) override; void visit(ReturnStmt& node) override; void visit(VarDeclStmt& node) override; diff --git a/src/lexer.cpp b/src/lexer.cpp index 138aa56..c493717 100644 --- a/src/lexer.cpp +++ b/src/lexer.cpp @@ -107,7 +107,12 @@ Token Lexer::scanToken() { case '.': return makeToken(TokenType::DOT, "."); case ':': return makeToken(TokenType::COLON, ":"); case ';': return makeToken(TokenType::SEMICOLON, ";"); - case '+': return makeToken(TokenType::PLUS, "+"); + case '+': + if (peek() == '+') { + advance(); + return makeToken(TokenType::PLUS_PLUS, "++"); + } + return makeToken(TokenType::PLUS, "+"); case '-': return makeToken(TokenType::MINUS, "-"); case '*': return makeToken(TokenType::STAR, "*"); case '/': @@ -122,8 +127,18 @@ Token Lexer::scanToken() { return makeToken(TokenType::EQUAL_EQUAL, "=="); } return makeToken(TokenType::EQUALS, "="); - case '<': return makeToken(TokenType::LESS, "<"); - case '>': return makeToken(TokenType::GREATER, ">"); + case '<': + if (peek() == '=') { + advance(); + return makeToken(TokenType::LESS_EQUAL, "<="); + } + return makeToken(TokenType::LESS, "<"); + case '>': + if (peek() == '=') { + advance(); + return makeToken(TokenType::GREATER_EQUAL, ">="); + } + return makeToken(TokenType::GREATER, ">"); case '"': return string(); default: return makeToken(TokenType::UNKNOWN, std::string(1, c)); } diff --git a/src/parser.cpp b/src/parser.cpp index 44bfaf1..dedaacb 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -366,7 +366,7 @@ std::unique_ptr Parser::equality() { std::unique_ptr Parser::comparison() { std::unique_ptr expr = term(); - while (match(TokenType::LESS) || match(TokenType::GREATER)) { + while (match(TokenType::LESS) || match(TokenType::GREATER) || match(TokenType::LESS_EQUAL) || match(TokenType::GREATER_EQUAL)) { std::string op = previous().value; std::unique_ptr right = term(); expr = std::make_unique(std::move(expr), op, std::move(right)); @@ -410,6 +410,8 @@ std::unique_ptr Parser::call() { std::unique_ptr index = expression(); consume(TokenType::RBRACKET, "Expect ']' after index."); expr = std::make_unique(std::move(expr), std::move(index)); + } else if (match(TokenType::PLUS_PLUS)) { + expr = std::make_unique(std::move(expr)); } else { break; } diff --git a/src/token.h b/src/token.h index d869176..e5346a8 100644 --- a/src/token.h +++ b/src/token.h @@ -41,6 +41,9 @@ enum class TokenType { EQUAL_EQUAL,// == LESS, // < GREATER, // > + LESS_EQUAL, // <= + GREATER_EQUAL, // >= + PLUS_PLUS, // ++ END_OF_FILE, UNKNOWN }; diff --git a/src/version.h b/src/version.h index a8ec32d..e0dc157 100644 --- a/src/version.h +++ b/src/version.h @@ -1,6 +1,6 @@ #ifndef SUN_VERSION_H #define SUN_VERSION_H -#define SUN_VERSION "0.3.1" +#define SUN_VERSION "0.4.0" #endif // SUN_VERSION_H diff --git a/sun b/sun index 6af4459..f720881 100755 Binary files a/sun and b/sun differ diff --git a/tests/test_for.sun b/tests/test_for.sun index 7f57a4a..d6b1697 100644 --- a/tests/test_for.sun +++ b/tests/test_for.sun @@ -1,7 +1,4 @@ -function main() { - print("Counting:"); - for (var i = 0; i < 5; i = i + 1) { - print(i); - } - return 0; +print("Counting:"); +for (var i = 0; i <= 5; i++) { + print(i); } \ No newline at end of file