add support do i++ and >= and <=

This commit is contained in:
Talles Amadeu 2025-12-09 14:37:28 -03:00
parent 05b3ff6e1e
commit 31d7ac51ed
17 changed files with 189 additions and 11 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -11,6 +11,7 @@ void SetPropExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); }
void ArrayExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); }
void IndexExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); }
void ArrayAssignExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); }
void IncrementExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); }
void BinaryExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); }
void ReturnStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); }
void VarDeclStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); }

View File

@ -108,6 +108,13 @@ public:
void accept(ASTVisitor& visitor) override;
};
class IncrementExpr : public Expr {
public:
std::unique_ptr<Expr> variable; // Can be VariableExpr, GetPropExpr, IndexExpr
IncrementExpr(std::unique_ptr<Expr> variable) : variable(std::move(variable)) {}
void accept(ASTVisitor& visitor) override;
};
class BinaryExpr : public Expr {
public:
std::unique_ptr<Expr> left;
@ -239,6 +246,7 @@ public:
virtual void visit(ArrayExpr& node) = 0;
virtual void visit(IndexExpr& node) = 0;
virtual void visit(ArrayAssignExpr& node) = 0;
virtual void visit(IncrementExpr& node) = 0;
virtual void visit(BinaryExpr& node) = 0;
virtual void visit(ReturnStmt& node) = 0;
virtual void visit(VarDeclStmt& node) = 0;

View File

@ -130,3 +130,60 @@ void ASTPrinter::visit(ClassDef& node) {
}
std::cout << "}";
}
void ASTPrinter::visit(ArrayExpr& node) {
std::cout << "[";
for (size_t i = 0; i < node.elements.size(); ++i) {
node.elements[i]->accept(*this);
if (i < node.elements.size() - 1) std::cout << ", ";
}
std::cout << "]";
}
void ASTPrinter::visit(IndexExpr& node) {
node.array->accept(*this);
std::cout << "[";
node.index->accept(*this);
std::cout << "]";
}
void ASTPrinter::visit(ArrayAssignExpr& node) {
node.array->accept(*this);
std::cout << "[";
node.index->accept(*this);
std::cout << "] = ";
node.value->accept(*this);
}
void ASTPrinter::visit(IncrementExpr& node) {
node.variable->accept(*this);
std::cout << "++";
}
void ASTPrinter::visit(ForInStmt& node) {
std::cout << "for (var " << node.variableName << " in ";
node.collection->accept(*this);
std::cout << ") ";
node.body->accept(*this);
}
void ASTPrinter::visit(SwitchStmt& node) {
std::cout << "switch (";
node.condition->accept(*this);
std::cout << ") {" << std::endl;
for (const auto& c : node.cases) {
std::cout << "case ";
c.value->accept(*this);
std::cout << ": ";
c.body->accept(*this);
}
if (node.defaultCase) {
std::cout << "default: ";
node.defaultCase->accept(*this);
}
std::cout << "}";
}
void ASTPrinter::visit(BreakStmt& node) {
std::cout << "break;";
}

View File

@ -16,6 +16,10 @@ public:
void visit(NewExpr& node) override;
void visit(GetPropExpr& node) override;
void visit(SetPropExpr& node) override;
void visit(ArrayExpr& node) override;
void visit(IndexExpr& node) override;
void visit(ArrayAssignExpr& node) override;
void visit(IncrementExpr& node) override;
void visit(BinaryExpr& node) override;
void visit(ReturnStmt& node) override;
void visit(VarDeclStmt& node) override;
@ -23,6 +27,9 @@ public:
void visit(IfStmt& node) override;
void visit(WhileStmt& node) override;
void visit(ForStmt& node) override;
void visit(ForInStmt& node) override;
void visit(SwitchStmt& node) override;
void visit(BreakStmt& node) override;
void visit(ExpressionStmt& node) override;
void visit(FunctionDef& node) override;
void visit(ClassDef& node) override;

View File

@ -248,6 +248,87 @@ void CodeGen::visit(CallExpr& node) {
lastClassName = "";
}
void CodeGen::visit(IncrementExpr& node) {
if (auto varExpr = dynamic_cast<VariableExpr*>(node.variable.get())) {
if (namedValues.find(varExpr->name) == namedValues.end()) {
std::cerr << "Unknown variable: " << varExpr->name << std::endl;
lastValue = nullptr;
return;
}
llvm::AllocaInst* alloca = namedValues[varExpr->name];
llvm::Value* val = builder->CreateLoad(alloca->getAllocatedType(), alloca, varExpr->name.c_str());
lastValue = val; // Return old value
llvm::Value* inc = builder->CreateAdd(val, llvm::ConstantInt::get(*context, llvm::APInt(32, 1)), "inc");
builder->CreateStore(inc, alloca);
} else if (auto getProp = dynamic_cast<GetPropExpr*>(node.variable.get())) {
getProp->object->accept(*this);
llvm::Value* objectPtr = lastValue;
std::string className = lastClassName;
if (!objectPtr) return;
if (classFields.find(className) == classFields.end()) {
std::cerr << "Unknown class type: " << className << std::endl;
lastValue = nullptr;
return;
}
const auto& fields = classFields[className];
int fieldIndex = -1;
for (size_t i = 0; i < fields.size(); ++i) {
if (fields[i] == getProp->name) {
fieldIndex = i;
break;
}
}
if (fieldIndex == -1) {
std::cerr << "Unknown field: " << getProp->name << " in class " << className << std::endl;
lastValue = nullptr;
return;
}
std::vector<llvm::Value*> indices;
indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, 0)));
indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, fieldIndex)));
llvm::Type* structType = classStructs[className];
llvm::Value* fieldPtr = builder->CreateGEP(structType, objectPtr, indices, "fieldptr");
llvm::Value* val = builder->CreateLoad(llvm::Type::getInt32Ty(*context), fieldPtr, "fieldval");
lastValue = val;
llvm::Value* inc = builder->CreateAdd(val, llvm::ConstantInt::get(*context, llvm::APInt(32, 1)), "inc");
builder->CreateStore(inc, fieldPtr);
} else if (auto indexExpr = dynamic_cast<IndexExpr*>(node.variable.get())) {
indexExpr->array->accept(*this);
llvm::Value* arrPtr = lastValue;
indexExpr->index->accept(*this);
llvm::Value* indexVal = lastValue;
llvm::Function* getFn = module->getFunction("sun_array_get");
llvm::Function* setFn = module->getFunction("sun_array_set");
llvm::Value* valPtr = builder->CreateCall(getFn, {arrPtr, indexVal}, "elem");
// Assume IntArray for increment
llvm::Value* val = builder->CreatePtrToInt(valPtr, llvm::Type::getInt32Ty(*context), "elemInt");
lastValue = val;
llvm::Value* inc = builder->CreateAdd(val, llvm::ConstantInt::get(*context, llvm::APInt(32, 1)), "inc");
// Store back
llvm::Value* voidVal = builder->CreateIntToPtr(inc, llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0));
builder->CreateCall(setFn, {arrPtr, indexVal, voidVal});
} else {
std::cerr << "Invalid increment target." << std::endl;
lastValue = nullptr;
}
}
void CodeGen::visit(BinaryExpr& node) {
node.left->accept(*this);
llvm::Value* L = lastValue;
@ -311,6 +392,12 @@ void CodeGen::visit(BinaryExpr& node) {
} else if (node.op == ">") {
lastValue = builder->CreateICmpSGT(L, R, "cmptmp");
lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp");
} else if (node.op == "<=") {
lastValue = builder->CreateICmpSLE(L, R, "cmptmp");
lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp");
} else if (node.op == ">=") {
lastValue = builder->CreateICmpSGE(L, R, "cmptmp");
lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp");
} else if (node.op == "==") {
lastValue = builder->CreateICmpEQ(L, R, "cmptmp");
lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp");

View File

@ -28,6 +28,7 @@ public:
void visit(ArrayExpr& node) override;
void visit(IndexExpr& node) override;
void visit(ArrayAssignExpr& node) override;
void visit(IncrementExpr& node) override;
void visit(BinaryExpr& node) override;
void visit(ReturnStmt& node) override;
void visit(VarDeclStmt& node) override;

View File

@ -107,7 +107,12 @@ Token Lexer::scanToken() {
case '.': return makeToken(TokenType::DOT, ".");
case ':': return makeToken(TokenType::COLON, ":");
case ';': return makeToken(TokenType::SEMICOLON, ";");
case '+': return makeToken(TokenType::PLUS, "+");
case '+':
if (peek() == '+') {
advance();
return makeToken(TokenType::PLUS_PLUS, "++");
}
return makeToken(TokenType::PLUS, "+");
case '-': return makeToken(TokenType::MINUS, "-");
case '*': return makeToken(TokenType::STAR, "*");
case '/':
@ -122,8 +127,18 @@ Token Lexer::scanToken() {
return makeToken(TokenType::EQUAL_EQUAL, "==");
}
return makeToken(TokenType::EQUALS, "=");
case '<': return makeToken(TokenType::LESS, "<");
case '>': return makeToken(TokenType::GREATER, ">");
case '<':
if (peek() == '=') {
advance();
return makeToken(TokenType::LESS_EQUAL, "<=");
}
return makeToken(TokenType::LESS, "<");
case '>':
if (peek() == '=') {
advance();
return makeToken(TokenType::GREATER_EQUAL, ">=");
}
return makeToken(TokenType::GREATER, ">");
case '"': return string();
default: return makeToken(TokenType::UNKNOWN, std::string(1, c));
}

View File

@ -366,7 +366,7 @@ std::unique_ptr<Expr> Parser::equality() {
std::unique_ptr<Expr> Parser::comparison() {
std::unique_ptr<Expr> expr = term();
while (match(TokenType::LESS) || match(TokenType::GREATER)) {
while (match(TokenType::LESS) || match(TokenType::GREATER) || match(TokenType::LESS_EQUAL) || match(TokenType::GREATER_EQUAL)) {
std::string op = previous().value;
std::unique_ptr<Expr> right = term();
expr = std::make_unique<BinaryExpr>(std::move(expr), op, std::move(right));
@ -410,6 +410,8 @@ std::unique_ptr<Expr> Parser::call() {
std::unique_ptr<Expr> index = expression();
consume(TokenType::RBRACKET, "Expect ']' after index.");
expr = std::make_unique<IndexExpr>(std::move(expr), std::move(index));
} else if (match(TokenType::PLUS_PLUS)) {
expr = std::make_unique<IncrementExpr>(std::move(expr));
} else {
break;
}

View File

@ -41,6 +41,9 @@ enum class TokenType {
EQUAL_EQUAL,// ==
LESS, // <
GREATER, // >
LESS_EQUAL, // <=
GREATER_EQUAL, // >=
PLUS_PLUS, // ++
END_OF_FILE,
UNKNOWN
};

View File

@ -1,6 +1,6 @@
#ifndef SUN_VERSION_H
#define SUN_VERSION_H
#define SUN_VERSION "0.3.1"
#define SUN_VERSION "0.4.0"
#endif // SUN_VERSION_H

BIN
sun

Binary file not shown.

View File

@ -1,7 +1,4 @@
function main() {
print("Counting:");
for (var i = 0; i < 5; i = i + 1) {
print(i);
}
return 0;
print("Counting:");
for (var i = 0; i <= 5; i++) {
print(i);
}