commit f8695b14e334c84a048403a16f487a34ec2ed63d Author: Talles Amadeu Date: Tue Dec 9 13:23:45 2025 -0300 first version of Sun diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3c1d723 --- /dev/null +++ b/Makefile @@ -0,0 +1,52 @@ +CXX = clang++ + +# Detect OS +UNAME_S := $(shell uname -s) + +# Default LLVM_CONFIG +LLVM_CONFIG ?= llvm-config + +# If llvm-config is not in PATH, try to guess location based on OS +ifeq ($(shell which $(LLVM_CONFIG) 2>/dev/null),) + ifeq ($(UNAME_S),Darwin) + # macOS Homebrew path + LLVM_CONFIG = /opt/homebrew/opt/llvm/bin/llvm-config + endif +endif + +# Verify if llvm-config exists +ifeq ($(shell which $(LLVM_CONFIG) 2>/dev/null),) + # If still not found, try common versioned names on Linux + ifneq ($(UNAME_S),Darwin) + LLVM_CONFIG := $(shell which llvm-config-18 2>/dev/null || which llvm-config-17 2>/dev/null || which llvm-config-16 2>/dev/null || which llvm-config-15 2>/dev/null) + endif +endif + +# Final check +ifeq ($(LLVM_CONFIG),) + $(error "llvm-config not found. Please install LLVM or set LLVM_CONFIG manually (e.g., make LLVM_CONFIG=llvm-config-15)") +endif + +CXXFLAGS = -std=c++17 -Wall -Wextra -Wno-unused-parameter $(shell $(LLVM_CONFIG) --cxxflags) +LDFLAGS = $(shell $(LLVM_CONFIG) --ldflags --system-libs --libs core) + +SRC_DIR = src +BUILD_DIR = build +TARGET = sun + +SRCS = $(wildcard $(SRC_DIR)/*.cpp) +OBJS = $(patsubst $(SRC_DIR)/%.cpp, $(BUILD_DIR)/%.o, $(SRCS)) + +all: $(TARGET) + +$(TARGET): $(OBJS) + $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) + +$(BUILD_DIR)/%.o: $(SRC_DIR)/%.cpp + @mkdir -p $(BUILD_DIR) + $(CXX) $(CXXFLAGS) -c -o $@ $< + +clean: + rm -rf $(BUILD_DIR) $(TARGET) + +.PHONY: all clean diff --git a/README.md b/README.md new file mode 100644 index 0000000..83147da --- /dev/null +++ b/README.md @@ -0,0 +1,58 @@ +# Sun Language Compiler + +## Instalação Rápida (Recomendado) + +O projeto inclui um script de instalação que: +1. Verifica e instala dependências (no Ubuntu/Debian). +2. Compila o projeto. +3. Instala o binário `sun` em `/usr/local/bin`. + +Basta rodar: + +```bash +./install.sh +``` + +## Instalação Manual + +### Pré-requisitos + +#### macOS +- Homebrew instalado +- LLVM instalado via Homebrew: + ```bash + brew install llvm@21 + ``` + +#### Ubuntu / Linux (Debian-based) +*O script `install.sh` tenta instalar estes pacotes automaticamente.* +- Build essentials e LLVM/Clang: + ```bash + sudo apt update + sudo apt install build-essential llvm clang + ``` + *Nota: O projeto requer suporte a C++17. Versões recentes do LLVM (15+) são recomendadas.* + +### Compilando +1. Abra o terminal na pasta do projeto. +2. Rode o comando: + ```bash + make + ``` + *Se o `llvm-config` não estiver no PATH ou tiver um nome diferente (ex: `llvm-config-15`), você pode especificá-lo:* + ```bash + make LLVM_CONFIG=llvm-config-15 + ``` +3. Isso irá gerar o executável `sun`. + +### Usando o Compilador +Para compilar e rodar um arquivo `.sun`: +```bash +./sun arquivo.sun +``` + +Para apenas compilar e gerar um executável: +```bash +./sun arquivo.sun -o nome_do_programa +./nome_do_programa +``` diff --git a/install.sh b/install.sh new file mode 100755 index 0000000..6128c88 --- /dev/null +++ b/install.sh @@ -0,0 +1,118 @@ +#!/bin/bash + +# Colors for output +GREEN='\033[0;32m' +RED='\033[0;31m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +echo -e "${GREEN}=== Sun Language Installer ===${NC}" + +# Function to check command existence +check_cmd() { + command -v "$1" &> /dev/null +} + +# 1. Check for build tools and install if missing (Ubuntu/Debian only) +echo "Checking dependencies..." + +if ! check_cmd make || ! check_cmd clang++; then + if [ -f /etc/debian_version ]; then + echo -e "${YELLOW}Dependencies missing. Attempting to install...${NC}" + echo "Running: sudo apt update && sudo apt install -y build-essential llvm clang" + sudo apt update && sudo apt install -y build-essential llvm clang + else + if ! check_cmd make; then echo -e "${RED}Error: 'make' is not installed.${NC}"; exit 1; fi + if ! check_cmd clang++; then echo -e "${RED}Error: 'clang++' is not installed.${NC}"; exit 1; fi + fi +fi + +# 2. Detect LLVM +echo "Detecting LLVM..." +LLVM_CONFIG="" + +# Try standard llvm-config +if command -v llvm-config &> /dev/null; then + LLVM_CONFIG="llvm-config" +fi + +# Try macOS Homebrew path +if [ -z "$LLVM_CONFIG" ] && [ -f "/opt/homebrew/opt/llvm/bin/llvm-config" ]; then + LLVM_CONFIG="/opt/homebrew/opt/llvm/bin/llvm-config" +fi + +# Try versioned names (common on Linux) +if [ -z "$LLVM_CONFIG" ]; then + for ver in 21 20 19 18 17 16 15; do + if command -v "llvm-config-$ver" &> /dev/null; then + LLVM_CONFIG="llvm-config-$ver" + break + fi + done +fi + +if [ -z "$LLVM_CONFIG" ]; then + echo -e "${RED}Error: LLVM not found.${NC}" + + if [ -f /etc/debian_version ]; then + echo -e "${YELLOW}Attempting to install LLVM...${NC}" + sudo apt install -y llvm + + # Try to find it again + if command -v llvm-config &> /dev/null; then + LLVM_CONFIG="llvm-config" + else + # Try versioned names again + for ver in 21 20 19 18 17 16 15; do + if command -v "llvm-config-$ver" &> /dev/null; then + LLVM_CONFIG="llvm-config-$ver" + break + fi + done + fi + fi +fi + +if [ -z "$LLVM_CONFIG" ]; then + echo -e "${RED}Error: LLVM could not be found or installed.${NC}" + echo "Please install LLVM manually:" + echo " - macOS: brew install llvm@21" + echo " - Ubuntu: sudo apt install llvm clang" + exit 1 +fi + +echo "Using LLVM config: $LLVM_CONFIG" + +# 3. Build +echo "Building Sun..." +# Pass LLVM_CONFIG to make +if ! make LLVM_CONFIG="$LLVM_CONFIG"; then + echo -e "${RED}Build failed.${NC}" + exit 1 +fi + +# 4. Install +INSTALL_DIR="/usr/local/bin" + +# Check if /usr/local/bin exists, if not create it +if [ ! -d "$INSTALL_DIR" ]; then + echo "Creating $INSTALL_DIR..." + sudo mkdir -p "$INSTALL_DIR" +fi + +echo "Installing 'sun' binary to $INSTALL_DIR..." + +if [ -w "$INSTALL_DIR" ]; then + cp sun "$INSTALL_DIR/sun" +else + echo -e "${YELLOW}Permission denied. Trying with sudo...${NC}" + if sudo cp sun "$INSTALL_DIR/sun"; then + echo "Copied successfully with sudo." + else + echo -e "${RED}Failed to install. Please run with sudo or check permissions.${NC}" + exit 1 + fi +fi + +echo -e "${GREEN}Success! Sun has been installed.${NC}" +echo "You can now run 'sun' from anywhere." diff --git a/src/ast.cpp b/src/ast.cpp new file mode 100644 index 0000000..2d6cfaf --- /dev/null +++ b/src/ast.cpp @@ -0,0 +1,26 @@ +#include "ast.h" + +void NumberExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void StringExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void VariableExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void AssignExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void CallExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void NewExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void GetPropExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void SetPropExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void ArrayExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void IndexExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void ArrayAssignExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void BinaryExpr::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void ReturnStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void VarDeclStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void BlockStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void IfStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void WhileStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void ForStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void ForInStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void SwitchStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void BreakStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void ExpressionStmt::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void FunctionDef::accept(ASTVisitor& visitor) { visitor.visit(*this); } +void ClassDef::accept(ASTVisitor& visitor) { visitor.visit(*this); } diff --git a/src/ast.h b/src/ast.h new file mode 100644 index 0000000..be613dd --- /dev/null +++ b/src/ast.h @@ -0,0 +1,257 @@ +#ifndef AST_H +#define AST_H + +#include +#include +#include + +// Forward declarations +class ASTVisitor; + +class ASTNode { +public: + virtual ~ASTNode() = default; + virtual void accept(ASTVisitor& visitor) = 0; +}; + +class Expr : public ASTNode {}; +class Stmt : public ASTNode {}; + +class NumberExpr : public Expr { +public: + int value; + NumberExpr(int value) : value(value) {} + void accept(ASTVisitor& visitor) override; +}; + +class StringExpr : public Expr { +public: + std::string value; + StringExpr(const std::string& value) : value(value) {} + void accept(ASTVisitor& visitor) override; +}; + +class VariableExpr : public Expr { +public: + std::string name; + VariableExpr(const std::string& name) : name(name) {} + void accept(ASTVisitor& visitor) override; +}; + +class AssignExpr : public Expr { +public: + std::string name; + std::unique_ptr value; + AssignExpr(const std::string& name, std::unique_ptr value) + : name(name), value(std::move(value)) {} + void accept(ASTVisitor& visitor) override; +}; + +class CallExpr : public Expr { +public: + std::string callee; + std::vector> args; + CallExpr(const std::string& callee, std::vector> args) + : callee(callee), args(std::move(args)) {} + void accept(ASTVisitor& visitor) override; +}; + +class NewExpr : public Expr { +public: + std::string className; + NewExpr(const std::string& className) : className(className) {} + void accept(ASTVisitor& visitor) override; +}; + +class GetPropExpr : public Expr { +public: + std::unique_ptr object; + std::string name; + GetPropExpr(std::unique_ptr object, const std::string& name) + : object(std::move(object)), name(name) {} + void accept(ASTVisitor& visitor) override; +}; + +class SetPropExpr : public Expr { +public: + std::unique_ptr object; + std::string name; + std::unique_ptr value; + SetPropExpr(std::unique_ptr object, const std::string& name, std::unique_ptr value) + : object(std::move(object)), name(name), value(std::move(value)) {} + void accept(ASTVisitor& visitor) override; +}; + +class ArrayExpr : public Expr { +public: + std::vector> elements; + ArrayExpr(std::vector> elements) : elements(std::move(elements)) {} + void accept(ASTVisitor& visitor) override; +}; + +class IndexExpr : public Expr { +public: + std::unique_ptr array; + std::unique_ptr index; + IndexExpr(std::unique_ptr array, std::unique_ptr index) + : array(std::move(array)), index(std::move(index)) {} + void accept(ASTVisitor& visitor) override; +}; + +class ArrayAssignExpr : public Expr { +public: + std::unique_ptr array; + std::unique_ptr index; + std::unique_ptr value; + ArrayAssignExpr(std::unique_ptr array, std::unique_ptr index, std::unique_ptr value) + : array(std::move(array)), index(std::move(index)), value(std::move(value)) {} + void accept(ASTVisitor& visitor) override; +}; + +class BinaryExpr : public Expr { +public: + std::unique_ptr left; + std::string op; + std::unique_ptr right; + BinaryExpr(std::unique_ptr left, std::string op, std::unique_ptr right) + : left(std::move(left)), op(op), right(std::move(right)) {} + void accept(ASTVisitor& visitor) override; +}; + +class ReturnStmt : public Stmt { +public: + std::unique_ptr value; + ReturnStmt(std::unique_ptr value) : value(std::move(value)) {} + void accept(ASTVisitor& visitor) override; +}; + +class BreakStmt : public Stmt { +public: + void accept(ASTVisitor& visitor) override; +}; + +class VarDeclStmt : public Stmt { +public: + std::string name; + std::unique_ptr initializer; + VarDeclStmt(const std::string& name, std::unique_ptr initializer) + : name(name), initializer(std::move(initializer)) {} + void accept(ASTVisitor& visitor) override; +}; + +class BlockStmt : public Stmt { +public: + std::vector> statements; + void accept(ASTVisitor& visitor) override; +}; + +class IfStmt : public Stmt { +public: + std::unique_ptr condition; + std::unique_ptr thenBranch; + std::unique_ptr elseBranch; + IfStmt(std::unique_ptr condition, std::unique_ptr thenBranch, std::unique_ptr elseBranch) + : condition(std::move(condition)), thenBranch(std::move(thenBranch)), elseBranch(std::move(elseBranch)) {} + void accept(ASTVisitor& visitor) override; +}; + +class WhileStmt : public Stmt { +public: + std::unique_ptr condition; + std::unique_ptr body; + WhileStmt(std::unique_ptr condition, std::unique_ptr body) + : condition(std::move(condition)), body(std::move(body)) {} + void accept(ASTVisitor& visitor) override; +}; + +class ForStmt : public Stmt { +public: + std::unique_ptr init; + std::unique_ptr condition; + std::unique_ptr increment; + std::unique_ptr body; + ForStmt(std::unique_ptr init, std::unique_ptr condition, std::unique_ptr increment, std::unique_ptr body) + : init(std::move(init)), condition(std::move(condition)), increment(std::move(increment)), body(std::move(body)) {} + void accept(ASTVisitor& visitor) override; +}; + +class ForInStmt : public Stmt { +public: + std::string variableName; + std::unique_ptr collection; + std::unique_ptr body; + ForInStmt(const std::string& variableName, std::unique_ptr collection, std::unique_ptr body) + : variableName(variableName), collection(std::move(collection)), body(std::move(body)) {} + void accept(ASTVisitor& visitor) override; +}; + +struct Case { + std::unique_ptr value; + std::unique_ptr body; +}; + +class SwitchStmt : public Stmt { +public: + std::unique_ptr condition; + std::vector cases; + std::unique_ptr defaultCase; + SwitchStmt(std::unique_ptr condition, std::vector cases, std::unique_ptr defaultCase) + : condition(std::move(condition)), cases(std::move(cases)), defaultCase(std::move(defaultCase)) {} + void accept(ASTVisitor& visitor) override; +}; + +class ExpressionStmt : public Stmt { +public: + std::unique_ptr expression; + ExpressionStmt(std::unique_ptr expression) : expression(std::move(expression)) {} + void accept(ASTVisitor& visitor) override; +}; + +class FunctionDef : public ASTNode { +public: + std::string name; + std::vector args; + std::unique_ptr body; + FunctionDef(const std::string& name, std::vector args, std::unique_ptr body) + : name(name), args(args), body(std::move(body)) {} + void accept(ASTVisitor& visitor) override; +}; + +class ClassDef : public ASTNode { +public: + std::string name; + std::vector fields; + ClassDef(const std::string& name, std::vector fields) + : name(name), fields(fields) {} + void accept(ASTVisitor& visitor) override; +}; + +class ASTVisitor { +public: + virtual void visit(NumberExpr& node) = 0; + virtual void visit(StringExpr& node) = 0; + virtual void visit(VariableExpr& node) = 0; + virtual void visit(AssignExpr& node) = 0; + virtual void visit(CallExpr& node) = 0; + virtual void visit(NewExpr& node) = 0; + virtual void visit(GetPropExpr& node) = 0; + virtual void visit(SetPropExpr& node) = 0; + virtual void visit(ArrayExpr& node) = 0; + virtual void visit(IndexExpr& node) = 0; + virtual void visit(ArrayAssignExpr& node) = 0; + virtual void visit(BinaryExpr& node) = 0; + virtual void visit(ReturnStmt& node) = 0; + virtual void visit(VarDeclStmt& node) = 0; + virtual void visit(BlockStmt& node) = 0; + virtual void visit(IfStmt& node) = 0; + virtual void visit(WhileStmt& node) = 0; + virtual void visit(ForStmt& node) = 0; + virtual void visit(ForInStmt& node) = 0; + virtual void visit(SwitchStmt& node) = 0; + virtual void visit(BreakStmt& node) = 0; + virtual void visit(ExpressionStmt& node) = 0; + virtual void visit(FunctionDef& node) = 0; + virtual void visit(ClassDef& node) = 0; +}; + +#endif // AST_H diff --git a/src/ast_extension.h b/src/ast_extension.h new file mode 100644 index 0000000..4e6bc58 --- /dev/null +++ b/src/ast_extension.h @@ -0,0 +1,6 @@ +class ExpressionStmt : public Stmt { +public: + std::unique_ptr expression; + ExpressionStmt(std::unique_ptr expression) : expression(std::move(expression)) {} + void accept(ASTVisitor& visitor) override; +}; diff --git a/src/ast_printer.cpp b/src/ast_printer.cpp new file mode 100644 index 0000000..ecf8842 --- /dev/null +++ b/src/ast_printer.cpp @@ -0,0 +1,132 @@ +#include "ast_printer.h" + +void ASTPrinter::print(ASTNode& node) { + node.accept(*this); + std::cout << std::endl; +} + +void ASTPrinter::visit(NumberExpr& node) { + std::cout << node.value; +} + +void ASTPrinter::visit(StringExpr& node) { + std::cout << "\"" << node.value << "\""; +} + +void ASTPrinter::visit(VariableExpr& node) { + std::cout << node.name; +} + +void ASTPrinter::visit(AssignExpr& node) { + std::cout << "(" << node.name << " = "; + node.value->accept(*this); + std::cout << ")"; +} + +void ASTPrinter::visit(CallExpr& node) { + std::cout << node.callee << "("; + for (size_t i = 0; i < node.args.size(); ++i) { + node.args[i]->accept(*this); + if (i < node.args.size() - 1) std::cout << ", "; + } + std::cout << ")"; +} + +void ASTPrinter::visit(BinaryExpr& node) { + std::cout << "("; + node.left->accept(*this); + std::cout << " " << node.op << " "; + node.right->accept(*this); + std::cout << ")"; +} + +void ASTPrinter::visit(ReturnStmt& node) { + std::cout << "return "; + if (node.value) { + node.value->accept(*this); + } + std::cout << ";"; +} + +void ASTPrinter::visit(VarDeclStmt& node) { + std::cout << "var " << node.name << " = "; + node.initializer->accept(*this); + std::cout << ";"; +} + +void ASTPrinter::visit(BlockStmt& node) { + std::cout << " {" << std::endl; + for (const auto& stmt : node.statements) { + std::cout << " "; + stmt->accept(*this); + std::cout << std::endl; + } + std::cout << "}"; +} + +void ASTPrinter::visit(IfStmt& node) { + std::cout << "if ("; + node.condition->accept(*this); + std::cout << ") "; + node.thenBranch->accept(*this); + if (node.elseBranch) { + std::cout << " else "; + node.elseBranch->accept(*this); + } +} + +void ASTPrinter::visit(WhileStmt& node) { + std::cout << "while ("; + node.condition->accept(*this); + std::cout << ") "; + node.body->accept(*this); +} + +void ASTPrinter::visit(ForStmt& node) { + std::cout << "for ("; + if (node.init) node.init->accept(*this); + std::cout << "; "; + if (node.condition) node.condition->accept(*this); + std::cout << "; "; + if (node.increment) node.increment->accept(*this); + std::cout << ") "; + node.body->accept(*this); +} + +void ASTPrinter::visit(ExpressionStmt& node) { + node.expression->accept(*this); + std::cout << ";"; +} + +void ASTPrinter::visit(FunctionDef& node) { + std::cout << "function " << node.name << "("; + for (size_t i = 0; i < node.args.size(); ++i) { + std::cout << node.args[i]; + if (i < node.args.size() - 1) std::cout << ", "; + } + std::cout << ")"; + node.body->accept(*this); +} + +void ASTPrinter::visit(NewExpr& node) { + std::cout << "new " << node.className << "()"; +} + +void ASTPrinter::visit(GetPropExpr& node) { + node.object->accept(*this); + std::cout << "." << node.name; +} + +void ASTPrinter::visit(SetPropExpr& node) { + node.object->accept(*this); + std::cout << "." << node.name << " = "; + node.value->accept(*this); +} + +void ASTPrinter::visit(ClassDef& node) { + std::cout << "class " << node.name << " {" << std::endl; + for (const auto& field : node.fields) { + std::cout << " var " << field << ";" << std::endl; + } + std::cout << "}"; +} diff --git a/src/ast_printer.h b/src/ast_printer.h new file mode 100644 index 0000000..90f861f --- /dev/null +++ b/src/ast_printer.h @@ -0,0 +1,31 @@ +#ifndef AST_PRINTER_H +#define AST_PRINTER_H + +#include "ast.h" +#include + +class ASTPrinter : public ASTVisitor { +public: + void print(ASTNode& node); + + void visit(NumberExpr& node) override; + void visit(StringExpr& node) override; + void visit(VariableExpr& node) override; + void visit(AssignExpr& node) override; + void visit(CallExpr& node) override; + void visit(NewExpr& node) override; + void visit(GetPropExpr& node) override; + void visit(SetPropExpr& node) override; + void visit(BinaryExpr& node) override; + void visit(ReturnStmt& node) override; + void visit(VarDeclStmt& node) override; + void visit(BlockStmt& node) override; + void visit(IfStmt& node) override; + void visit(WhileStmt& node) override; + void visit(ForStmt& node) override; + void visit(ExpressionStmt& node) override; + void visit(FunctionDef& node) override; + void visit(ClassDef& node) override; +}; + +#endif // AST_PRINTER_H diff --git a/src/codegen.cpp b/src/codegen.cpp new file mode 100644 index 0000000..103eeec --- /dev/null +++ b/src/codegen.cpp @@ -0,0 +1,864 @@ +#include "codegen.h" +#include + +CodeGen::CodeGen() { + context = std::make_unique(); + module = std::make_unique("sun_module", *context); + builder = std::make_unique>(*context); + + // Declare runtime functions + // void print_int(int) + llvm::FunctionType* printIntType = llvm::FunctionType::get(llvm::Type::getVoidTy(*context), {llvm::Type::getInt32Ty(*context)}, false); + llvm::Function::Create(printIntType, llvm::Function::ExternalLinkage, "print_int", module.get()); + + // void print_string(char*) + llvm::FunctionType* printStrType = llvm::FunctionType::get(llvm::Type::getVoidTy(*context), {llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0)}, false); + llvm::Function::Create(printStrType, llvm::Function::ExternalLinkage, "print_string", module.get()); + + // char* int_to_str(int) + llvm::FunctionType* intToStrType = llvm::FunctionType::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0), {llvm::Type::getInt32Ty(*context)}, false); + llvm::Function::Create(intToStrType, llvm::Function::ExternalLinkage, "int_to_str", module.get()); + + // char* str_concat(char*, char*) + llvm::FunctionType* strConcatType = llvm::FunctionType::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0), {llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0), llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0)}, false); + llvm::Function::Create(strConcatType, llvm::Function::ExternalLinkage, "str_concat", module.get()); + + // void* sun_array_create(int size) + llvm::FunctionType* arrayCreateType = llvm::FunctionType::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0), {llvm::Type::getInt32Ty(*context)}, false); + llvm::Function::Create(arrayCreateType, llvm::Function::ExternalLinkage, "sun_array_create", module.get()); + + // void sun_array_set(void* arr, int index, void* value) + llvm::FunctionType* arraySetType = llvm::FunctionType::get(llvm::Type::getVoidTy(*context), {llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0), llvm::Type::getInt32Ty(*context), llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0)}, false); + llvm::Function::Create(arraySetType, llvm::Function::ExternalLinkage, "sun_array_set", module.get()); + + // void* sun_array_get(void* arr, int index) + llvm::FunctionType* arrayGetType = llvm::FunctionType::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0), {llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0), llvm::Type::getInt32Ty(*context)}, false); + llvm::Function::Create(arrayGetType, llvm::Function::ExternalLinkage, "sun_array_get", module.get()); + + // int sun_array_length(void* arr) + llvm::FunctionType* arrayLenType = llvm::FunctionType::get(llvm::Type::getInt32Ty(*context), {llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0)}, false); + llvm::Function::Create(arrayLenType, llvm::Function::ExternalLinkage, "sun_array_length", module.get()); +} + +llvm::AllocaInst* CodeGen::createEntryBlockAlloca(llvm::Function* theFunction, const std::string& varName, llvm::Type* type) { + llvm::IRBuilder<> tmpBuilder(&theFunction->getEntryBlock(), theFunction->getEntryBlock().begin()); + return tmpBuilder.CreateAlloca(type, nullptr, varName); +} + +void CodeGen::generate(ASTNode& node) { + if (dynamic_cast(&node) || dynamic_cast(&node)) { + if (mainFunction && builder->GetInsertBlock()) { + mainInsertBlock = builder->GetInsertBlock(); + } + node.accept(*this); + } else { + if (!mainFunction) { + llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getInt32Ty(*context), false); + mainFunction = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "main", module.get()); + llvm::BasicBlock* bb = llvm::BasicBlock::Create(*context, "entry", mainFunction); + builder->SetInsertPoint(bb); + } else { + if (mainInsertBlock) { + builder->SetInsertPoint(mainInsertBlock); + } + } + node.accept(*this); + mainInsertBlock = builder->GetInsertBlock(); + } +} + +void CodeGen::finish() { + if (mainFunction) { + if (mainInsertBlock && !mainInsertBlock->getTerminator()) { + builder->SetInsertPoint(mainInsertBlock); + builder->CreateRet(llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true))); + } + } +} + +void CodeGen::print() { + module->print(llvm::outs(), nullptr); +} + +void CodeGen::visit(NumberExpr& node) { + lastValue = llvm::ConstantInt::get(*context, llvm::APInt(32, node.value, true)); + lastClassName = ""; // Not a class +} + +void CodeGen::visit(StringExpr& node) { + lastValue = builder->CreateGlobalString(node.value); + lastClassName = "String"; // Special marker for string +} + +void CodeGen::visit(VariableExpr& node) { + if (namedValues.find(node.name) == namedValues.end()) { + std::cerr << "Unknown variable name: " << node.name << std::endl; + lastValue = nullptr; + lastClassName = ""; + return; + } + // Load the value from the alloca + llvm::AllocaInst* alloca = namedValues[node.name]; + lastValue = builder->CreateLoad(alloca->getAllocatedType(), alloca, node.name.c_str()); + + if (varTypes.find(node.name) != varTypes.end()) { + lastClassName = varTypes[node.name]; + } else { + lastClassName = ""; + } +} + +void CodeGen::visit(AssignExpr& node) { + node.value->accept(*this); + llvm::Value* val = lastValue; + std::string valClassName = lastClassName; // Capture type of value + + if (!val) return; + + if (namedValues.find(node.name) == namedValues.end()) { + std::cerr << "Unknown variable name: " << node.name << std::endl; + lastValue = nullptr; + return; + } + + llvm::AllocaInst* alloca = namedValues[node.name]; + builder->CreateStore(val, alloca); + lastValue = val; + + // Update type tracking + if (!valClassName.empty()) { + varTypes[node.name] = valClassName; + lastClassName = valClassName; + } +} + +void CodeGen::visit(CallExpr& node) { + // Handle "print" specially or as a normal function + if (node.callee == "print") { + if (node.args.size() != 1) { + std::cerr << "print() takes exactly 1 argument." << std::endl; + lastValue = nullptr; + return; + } + node.args[0]->accept(*this); + llvm::Value* argVal = lastValue; + std::string argType = lastClassName; + + if (argType == "String") { + llvm::Function* printStr = module->getFunction("print_string"); + builder->CreateCall(printStr, {argVal}); + } else { + // Assume int + llvm::Function* printInt = module->getFunction("print_int"); + builder->CreateCall(printInt, {argVal}); + } + lastValue = nullptr; // print returns void + return; + } + + llvm::Function* calleeF = module->getFunction(node.callee); + if (!calleeF) { + std::cerr << "Unknown function referenced: " << node.callee << std::endl; + lastValue = nullptr; + return; + } + + if (calleeF->arg_size() != node.args.size()) { + std::cerr << "Incorrect # arguments passed." << std::endl; + lastValue = nullptr; + return; + } + + std::vector argsV; + for (unsigned i = 0, e = node.args.size(); i != e; ++i) { + node.args[i]->accept(*this); + if (!lastValue) return; + argsV.push_back(lastValue); + } + + lastValue = builder->CreateCall(calleeF, argsV, "calltmp"); + // We don't know the return type class name easily without function signatures in symbol table. + // For now, assume int. + lastClassName = ""; +} + +void CodeGen::visit(BinaryExpr& node) { + node.left->accept(*this); + llvm::Value* L = lastValue; + std::string typeL = lastClassName; + + node.right->accept(*this); + llvm::Value* R = lastValue; + std::string typeR = lastClassName; + + if (!L || !R) { + lastValue = nullptr; + return; + } + + // String Concatenation + if (node.op == "+") { + bool isStrL = (typeL == "String"); + bool isStrR = (typeR == "String"); + + if (isStrL || isStrR) { + llvm::Value* strL = L; + llvm::Value* strR = R; + + if (!isStrL) { + // Convert int to string + llvm::Function* intToStr = module->getFunction("int_to_str"); + strL = builder->CreateCall(intToStr, {L}, "l_str"); + } + if (!isStrR) { + // Convert int to string + llvm::Function* intToStr = module->getFunction("int_to_str"); + strR = builder->CreateCall(intToStr, {R}, "r_str"); + } + + llvm::Function* strConcat = module->getFunction("str_concat"); + lastValue = builder->CreateCall(strConcat, {strL, strR}, "concat"); + lastClassName = "String"; + return; + } + } + + // Check if we are doing pointer arithmetic or struct comparison (not supported yet) + if (!L->getType()->isIntegerTy() || !R->getType()->isIntegerTy()) { + std::cerr << "Binary operations only supported for integers." << std::endl; + lastValue = nullptr; + return; + } + + if (node.op == "+") { + lastValue = builder->CreateAdd(L, R, "addtmp"); + } else if (node.op == "-") { + lastValue = builder->CreateSub(L, R, "subtmp"); + } else if (node.op == "*") { + lastValue = builder->CreateMul(L, R, "multmp"); + } else if (node.op == "/") { + lastValue = builder->CreateSDiv(L, R, "divtmp"); + } else if (node.op == "<") { + lastValue = builder->CreateICmpSLT(L, R, "cmptmp"); + // Convert i1 to i32 for now, as everything is i32 + lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp"); + } else if (node.op == ">") { + lastValue = builder->CreateICmpSGT(L, R, "cmptmp"); + lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp"); + } else if (node.op == "==") { + lastValue = builder->CreateICmpEQ(L, R, "cmptmp"); + lastValue = builder->CreateIntCast(lastValue, llvm::Type::getInt32Ty(*context), true, "booltmp"); + } else { + std::cerr << "Unknown operator: " << node.op << std::endl; + lastValue = nullptr; + } +} + +void CodeGen::visit(ReturnStmt& node) { + if (node.value) { + node.value->accept(*this); + if (lastValue) { + builder->CreateRet(lastValue); + } + } else { + builder->CreateRetVoid(); + } +} + +void CodeGen::visit(VarDeclStmt& node) { + llvm::Function* theFunction = builder->GetInsertBlock()->getParent(); + llvm::Value* initVal; + std::string initClassName = ""; + + if (node.initializer) { + node.initializer->accept(*this); + initVal = lastValue; + initClassName = lastClassName; + } else { + initVal = llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)); + } + + llvm::AllocaInst* alloca = createEntryBlockAlloca(theFunction, node.name, initVal->getType()); + builder->CreateStore(initVal, alloca); + namedValues[node.name] = alloca; + + if (!initClassName.empty()) { + varTypes[node.name] = initClassName; + } +} + +void CodeGen::visit(BlockStmt& node) { + for (const auto& stmt : node.statements) { + stmt->accept(*this); + } +} + +void CodeGen::visit(IfStmt& node) { + node.condition->accept(*this); + llvm::Value* condV = lastValue; + if (!condV) return; + + // Convert condition to bool (i1) by comparing not equal to 0 + condV = builder->CreateICmpNE(condV, llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)), "ifcond"); + + llvm::Function* theFunction = builder->GetInsertBlock()->getParent(); + + llvm::BasicBlock* thenBB = llvm::BasicBlock::Create(*context, "then", theFunction); + llvm::BasicBlock* elseBB = llvm::BasicBlock::Create(*context, "else"); + llvm::BasicBlock* mergeBB = llvm::BasicBlock::Create(*context, "ifcont"); + + builder->CreateCondBr(condV, thenBB, elseBB); + + // Emit then value. + builder->SetInsertPoint(thenBB); + node.thenBranch->accept(*this); + + // Only branch to merge if the block isn't already terminated (e.g. by return) + if (!builder->GetInsertBlock()->getTerminator()) { + builder->CreateBr(mergeBB); + } + + // Emit else block. + theFunction->insert(theFunction->end(), elseBB); + builder->SetInsertPoint(elseBB); + if (node.elseBranch) { + node.elseBranch->accept(*this); + } + + if (!builder->GetInsertBlock()->getTerminator()) { + builder->CreateBr(mergeBB); + } + + // Emit merge block. + theFunction->insert(theFunction->end(), mergeBB); + builder->SetInsertPoint(mergeBB); +} + +void CodeGen::visit(WhileStmt& node) { + llvm::Function* theFunction = builder->GetInsertBlock()->getParent(); + + llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction); + llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody"); + llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter"); + + // Jump to condition + builder->CreateBr(condBB); + + // Condition Block + builder->SetInsertPoint(condBB); + node.condition->accept(*this); + llvm::Value* condV = lastValue; + if (!condV) return; + + condV = builder->CreateICmpNE(condV, llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)), "loopcond"); + builder->CreateCondBr(condV, bodyBB, afterBB); + + // Body Block + theFunction->insert(theFunction->end(), bodyBB); + builder->SetInsertPoint(bodyBB); + + breakStack.push_back(afterBB); + node.body->accept(*this); + breakStack.pop_back(); + + builder->CreateBr(condBB); // Loop back to condition + + // After Block + theFunction->insert(theFunction->end(), afterBB); + builder->SetInsertPoint(afterBB); +} + +void CodeGen::visit(ForStmt& node) { + llvm::Function* theFunction = builder->GetInsertBlock()->getParent(); + + // Emit init + if (node.init) { + node.init->accept(*this); + } + + llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction); + llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody"); + llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter"); + + // Jump to condition + builder->CreateBr(condBB); + + // Condition Block + builder->SetInsertPoint(condBB); + + llvm::Value* condV = nullptr; + if (node.condition) { + node.condition->accept(*this); + condV = lastValue; + if (!condV) return; + condV = builder->CreateICmpNE(condV, llvm::ConstantInt::get(*context, llvm::APInt(32, 0, true)), "loopcond"); + } else { + // Infinite loop if no condition + condV = llvm::ConstantInt::get(*context, llvm::APInt(1, 1, true)); + } + + builder->CreateCondBr(condV, bodyBB, afterBB); + + // Body Block + theFunction->insert(theFunction->end(), bodyBB); + builder->SetInsertPoint(bodyBB); + + breakStack.push_back(afterBB); + node.body->accept(*this); + breakStack.pop_back(); + + // Increment + if (node.increment) { + node.increment->accept(*this); + } + + builder->CreateBr(condBB); // Loop back to condition + + // After Block + theFunction->insert(theFunction->end(), afterBB); + builder->SetInsertPoint(afterBB); +} + +void CodeGen::visit(ForInStmt& node) { + // 1. Evaluate collection (array) + node.collection->accept(*this); + llvm::Value* arrPtr = lastValue; + std::string arrayType = lastClassName; + + if (!arrPtr) return; + + // 2. Get Array Length + llvm::Function* lenFn = module->getFunction("sun_array_length"); + llvm::Value* lenVal = builder->CreateCall(lenFn, {arrPtr}, "len"); + + // 3. Create Loop Variable (index) + llvm::Function* theFunction = builder->GetInsertBlock()->getParent(); + llvm::AllocaInst* indexAlloca = createEntryBlockAlloca(theFunction, "index_" + node.variableName, llvm::Type::getInt32Ty(*context)); + builder->CreateStore(llvm::ConstantInt::get(*context, llvm::APInt(32, 0)), indexAlloca); + + // 4. Create Loop Blocks + llvm::BasicBlock* condBB = llvm::BasicBlock::Create(*context, "loopcond", theFunction); + llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(*context, "loopbody"); + llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(*context, "loopafter"); + + builder->CreateBr(condBB); + + // 5. Condition: index < length + builder->SetInsertPoint(condBB); + llvm::Value* indexVal = builder->CreateLoad(llvm::Type::getInt32Ty(*context), indexAlloca, "index"); + llvm::Value* condV = builder->CreateICmpSLT(indexVal, lenVal, "loopcond"); + builder->CreateCondBr(condV, bodyBB, afterBB); + + // 6. Body + theFunction->insert(theFunction->end(), bodyBB); + builder->SetInsertPoint(bodyBB); + + // Fetch element at index + llvm::Function* getFn = module->getFunction("sun_array_get"); + llvm::Value* elemPtr = builder->CreateCall(getFn, {arrPtr, indexVal}, "elem"); + + llvm::Value* elemVal; + std::string elemClassName; + + if (arrayType == "StringArray") { + elemVal = elemPtr; // Already i8* + elemClassName = "String"; + } else { + // Assume IntArray + elemVal = builder->CreatePtrToInt(elemPtr, llvm::Type::getInt32Ty(*context), "elemInt"); + elemClassName = ""; + } + + // Create variable for element + // Check if variable already exists (shadowing?) or create new scope? + // For simplicity, create new alloca in entry block (or reuse if exists) + // But we are inside a loop, so we should probably create it once in entry block. + // However, we need to update it every iteration. + + llvm::AllocaInst* varAlloca; + if (namedValues.find(node.variableName) == namedValues.end()) { + varAlloca = createEntryBlockAlloca(theFunction, node.variableName, elemVal->getType()); + namedValues[node.variableName] = varAlloca; + } else { + varAlloca = namedValues[node.variableName]; + } + + builder->CreateStore(elemVal, varAlloca); + + // Update type tracking + if (!elemClassName.empty()) { + varTypes[node.variableName] = elemClassName; + } + + // Execute body + breakStack.push_back(afterBB); + node.body->accept(*this); + breakStack.pop_back(); + + // Increment index + llvm::Value* nextIndex = builder->CreateAdd(indexVal, llvm::ConstantInt::get(*context, llvm::APInt(32, 1)), "nextindex"); + builder->CreateStore(nextIndex, indexAlloca); + + builder->CreateBr(condBB); + + // After Block + theFunction->insert(theFunction->end(), afterBB); + builder->SetInsertPoint(afterBB); +} + +void CodeGen::visit(ExpressionStmt& node) { + // Check if it's a print call (hack for now, since we don't have function calls as expressions yet) + // Actually, we don't have CallExpr yet. + // But we can check if the expression is a special "print" node? + // No, let's assume print is a function call. + // But we don't have CallExpr. + // Let's implement CallExpr? Or just handle it in Parser as a statement? + // The user asked for print() function. + // If I implement CallExpr, I can call any function. + // For now, let's assume print is handled via a special AST node or just CallExpr. + // I haven't implemented CallExpr yet. + // Let's implement CallExpr in AST. + node.expression->accept(*this); +} + +void CodeGen::visit(FunctionDef& node) { + // Save current namedValues (which might be main's locals) + auto oldNamedValues = namedValues; + + // 1. Define Function Type (Int32 -> Int32, Int32...) + std::vector ints(node.args.size(), llvm::Type::getInt32Ty(*context)); + llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getInt32Ty(*context), ints, false); + + // 2. Create Function + llvm::Function* f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, node.name, module.get()); + + // 3. Set Argument Names + unsigned idx = 0; + for (auto& arg : f->args()) { + arg.setName(node.args[idx++]); + } + + // 4. Create Entry Block + llvm::BasicBlock* bb = llvm::BasicBlock::Create(*context, "entry", f); + builder->SetInsertPoint(bb); + + // 5. Record Arguments in Symbol Table (Allocas) + namedValues.clear(); + for (auto& arg : f->args()) { + // Create an alloca for this variable. + llvm::AllocaInst* alloca = createEntryBlockAlloca(f, std::string(arg.getName()), arg.getType()); + + // Store the initial value into the alloca. + builder->CreateStore(&arg, alloca); + + // Add arguments to variable symbol table. + namedValues[std::string(arg.getName())] = alloca; + } + + // 6. Generate Body + node.body->accept(*this); + + // 7. Verify Function + llvm::verifyFunction(*f); + + // Restore namedValues + namedValues = oldNamedValues; +} + +void CodeGen::visit(ClassDef& node) { + // 1. Create Struct Type + std::vector fieldTypes(node.fields.size(), llvm::Type::getInt32Ty(*context)); // All fields are i32 for now + llvm::StructType* structType = llvm::StructType::create(*context, fieldTypes, node.name); + + classStructs[node.name] = structType; + classFields[node.name] = node.fields; +} + +void CodeGen::visit(NewExpr& node) { + if (classStructs.find(node.className) == classStructs.end()) { + std::cerr << "Unknown class: " << node.className << std::endl; + lastValue = nullptr; + lastClassName = ""; + return; + } + + llvm::StructType* structType = classStructs[node.className]; + + // Allocate memory for the object (on heap ideally, but stack for now or malloc) + // For simplicity, let's use malloc via LLVM IR or just alloca if we want stack allocation. + // Let's use alloca for now (stack allocated objects). + llvm::Function* theFunction = builder->GetInsertBlock()->getParent(); + llvm::AllocaInst* alloca = createEntryBlockAlloca(theFunction, "new_" + node.className, structType); + + lastValue = alloca; // The "object" is a pointer to the struct + lastClassName = node.className; +} + +void CodeGen::visit(GetPropExpr& node) { + // 1. Evaluate object expression + node.object->accept(*this); + llvm::Value* objectPtr = lastValue; + std::string className = lastClassName; + + if (!objectPtr) return; + + if (className.empty()) { + std::cerr << "Cannot determine class type for property access." << std::endl; + lastValue = nullptr; + return; + } + + if (classFields.find(className) == classFields.end()) { + std::cerr << "Unknown class type: " << className << std::endl; + lastValue = nullptr; + return; + } + + const auto& fields = classFields[className]; + int fieldIndex = -1; + for (size_t i = 0; i < fields.size(); ++i) { + if (fields[i] == node.name) { + fieldIndex = i; + break; + } + } + + if (fieldIndex == -1) { + std::cerr << "Unknown field: " << node.name << " in class " << className << std::endl; + lastValue = nullptr; + return; + } + + // 3. Generate GEP and Load + std::vector indices; + indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, 0))); // Dereference pointer + indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, fieldIndex))); // Field index + + llvm::Type* structType = classStructs[className]; + llvm::Value* fieldPtr = builder->CreateGEP(structType, objectPtr, indices, "fieldptr"); + lastValue = builder->CreateLoad(llvm::Type::getInt32Ty(*context), fieldPtr, "fieldval"); + lastClassName = ""; // Field is int +} + +void CodeGen::visit(SetPropExpr& node) { + // 1. Evaluate object + node.object->accept(*this); + llvm::Value* objectPtr = lastValue; + std::string className = lastClassName; + + if (!objectPtr) return; + + // 2. Evaluate value + node.value->accept(*this); + llvm::Value* val = lastValue; + + if (className.empty()) { + std::cerr << "Cannot determine class type for property set." << std::endl; + return; + } + + const auto& fields = classFields[className]; + int fieldIndex = -1; + for (size_t i = 0; i < fields.size(); ++i) { + if (fields[i] == node.name) { + fieldIndex = i; + break; + } + } + + if (fieldIndex == -1) { + std::cerr << "Unknown field: " << node.name << " in class " << className << std::endl; + return; + } + + std::vector indices; + indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, 0))); + indices.push_back(llvm::ConstantInt::get(*context, llvm::APInt(32, fieldIndex))); + + llvm::Type* structType = classStructs[className]; + llvm::Value* fieldPtr = builder->CreateGEP(structType, objectPtr, indices, "fieldptr"); + builder->CreateStore(val, fieldPtr); + lastValue = val; +} + +void CodeGen::visit(ArrayExpr& node) { + int size = node.elements.size(); + llvm::Function* createFn = module->getFunction("sun_array_create"); + llvm::Function* setFn = module->getFunction("sun_array_set"); + + llvm::Value* sizeVal = llvm::ConstantInt::get(*context, llvm::APInt(32, size)); + llvm::Value* arrPtr = builder->CreateCall(createFn, {sizeVal}, "array"); + + std::string elemType = "IntArray"; // Default to IntArray + + for (int i = 0; i < size; ++i) { + node.elements[i]->accept(*this); + llvm::Value* val = lastValue; + + if (lastClassName == "String") { + elemType = "StringArray"; + } + + // Cast value to i8* (void*) + llvm::Value* voidVal; + if (val->getType()->isPointerTy()) { + voidVal = builder->CreateBitCast(val, llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0)); + } else { + // Assume int (i32), cast to pointer sized int then inttoptr? + // Or just inttoptr directly. + voidVal = builder->CreateIntToPtr(val, llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0)); + } + + llvm::Value* indexVal = llvm::ConstantInt::get(*context, llvm::APInt(32, i)); + builder->CreateCall(setFn, {arrPtr, indexVal, voidVal}); + } + + lastValue = arrPtr; + lastClassName = elemType; +} + +void CodeGen::visit(IndexExpr& node) { + node.array->accept(*this); + llvm::Value* arrPtr = lastValue; + std::string arrayType = lastClassName; + + node.index->accept(*this); + llvm::Value* indexVal = lastValue; + + llvm::Function* getFn = module->getFunction("sun_array_get"); + llvm::Value* valPtr = builder->CreateCall(getFn, {arrPtr, indexVal}, "elem"); + + if (arrayType == "StringArray") { + lastValue = valPtr; // Already i8* + lastClassName = "String"; + } else { + // Assume IntArray, cast back to i32 + lastValue = builder->CreatePtrToInt(valPtr, llvm::Type::getInt32Ty(*context), "elemInt"); + lastClassName = ""; + } +} + +void CodeGen::visit(ArrayAssignExpr& node) { + node.array->accept(*this); + llvm::Value* arrPtr = lastValue; + + node.index->accept(*this); + llvm::Value* indexVal = lastValue; + + node.value->accept(*this); + llvm::Value* val = lastValue; + + llvm::Function* setFn = module->getFunction("sun_array_set"); + + // Cast value to i8* + llvm::Value* voidVal; + if (val->getType()->isPointerTy()) { + voidVal = builder->CreateBitCast(val, llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0)); + } else { + voidVal = builder->CreateIntToPtr(val, llvm::PointerType::get(llvm::Type::getInt8Ty(*context), 0)); + } + + builder->CreateCall(setFn, {arrPtr, indexVal, voidVal}); + lastValue = val; +} + +void CodeGen::visit(SwitchStmt& node) { + // 1. Evaluate condition + node.condition->accept(*this); + llvm::Value* condVal = lastValue; + if (!condVal) return; + + llvm::Function* theFunction = builder->GetInsertBlock()->getParent(); + + // 2. Create Merge Block (after switch) + llvm::BasicBlock* mergeBB = llvm::BasicBlock::Create(*context, "switchmerge"); + + // Push mergeBB to breakStack so 'break' jumps here + breakStack.push_back(mergeBB); + + // 3. Create Default Block + llvm::BasicBlock* defaultBB = llvm::BasicBlock::Create(*context, "switchdefault"); + + // 4. Create Blocks for Cases + std::vector caseBBs; + for (size_t i = 0; i < node.cases.size(); ++i) { + caseBBs.push_back(llvm::BasicBlock::Create(*context, "case" + std::to_string(i))); + } + + // 5. Generate Comparisons (If-Else Chain) + llvm::BasicBlock* currentTestBB = builder->GetInsertBlock(); + + for (size_t i = 0; i < node.cases.size(); ++i) { + // Evaluate case value + node.cases[i].value->accept(*this); + llvm::Value* caseVal = lastValue; + + // Compare + llvm::Value* cmp = builder->CreateICmpEQ(condVal, caseVal, "casecmp"); + + llvm::BasicBlock* nextTestBB; + if (i < node.cases.size() - 1) { + nextTestBB = llvm::BasicBlock::Create(*context, "nexttest", theFunction); + } else { + nextTestBB = defaultBB; + } + + builder->CreateCondBr(cmp, caseBBs[i], nextTestBB); + + if (i < node.cases.size() - 1) { + builder->SetInsertPoint(nextTestBB); + currentTestBB = nextTestBB; + } + } + + if (node.cases.empty()) { + builder->CreateBr(defaultBB); + } + + // 6. Generate Case Bodies + for (size_t i = 0; i < node.cases.size(); ++i) { + theFunction->insert(theFunction->end(), caseBBs[i]); + builder->SetInsertPoint(caseBBs[i]); + + node.cases[i].body->accept(*this); + + // Fallthrough + if (!builder->GetInsertBlock()->getTerminator()) { + if (i < node.cases.size() - 1) { + builder->CreateBr(caseBBs[i+1]); + } else { + builder->CreateBr(defaultBB); + } + } + } + + // 7. Generate Default Body + theFunction->insert(theFunction->end(), defaultBB); + builder->SetInsertPoint(defaultBB); + if (node.defaultCase) { + node.defaultCase->accept(*this); + } + + if (!builder->GetInsertBlock()->getTerminator()) { + builder->CreateBr(mergeBB); + } + + // 8. Finish + breakStack.pop_back(); + theFunction->insert(theFunction->end(), mergeBB); + builder->SetInsertPoint(mergeBB); +} + +void CodeGen::visit(BreakStmt& node) { + if (breakStack.empty()) { + std::cerr << "Error: break statement outside loop or switch" << std::endl; + return; + } + + builder->CreateBr(breakStack.back()); + + llvm::Function* theFunction = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock* deadBB = llvm::BasicBlock::Create(*context, "dead"); + theFunction->insert(theFunction->end(), deadBB); + builder->SetInsertPoint(deadBB); +} diff --git a/src/codegen.h b/src/codegen.h new file mode 100644 index 0000000..dbcfb82 --- /dev/null +++ b/src/codegen.h @@ -0,0 +1,63 @@ +#ifndef CODEGEN_H +#define CODEGEN_H + +#include "ast.h" +#include +#include +#include +#include +#include +#include + +class CodeGen : public ASTVisitor { +public: + CodeGen(); + void generate(ASTNode& node); + void finish(); + void print(); + llvm::Module* getModule() { return module.get(); } + + void visit(NumberExpr& node) override; + void visit(StringExpr& node) override; + void visit(VariableExpr& node) override; + void visit(AssignExpr& node) override; + void visit(CallExpr& node) override; + void visit(NewExpr& node) override; + void visit(GetPropExpr& node) override; + void visit(SetPropExpr& node) override; + void visit(ArrayExpr& node) override; + void visit(IndexExpr& node) override; + void visit(ArrayAssignExpr& node) override; + void visit(BinaryExpr& node) override; + void visit(ReturnStmt& node) override; + void visit(VarDeclStmt& node) override; + void visit(BlockStmt& node) override; + void visit(IfStmt& node) override; + void visit(WhileStmt& node) override; + void visit(ForStmt& node) override; + void visit(ForInStmt& node) override; + void visit(SwitchStmt& node) override; + void visit(BreakStmt& node) override; + void visit(ExpressionStmt& node) override; + void visit(FunctionDef& node) override; + void visit(ClassDef& node) override; + +private: + std::unique_ptr context; + std::unique_ptr module; + std::unique_ptr> builder; + std::map namedValues; + std::map classStructs; // Map class name to LLVM StructType + std::map> classFields; // Map class name to field names (for index lookup) + std::map varTypes; // Map variable name to class name (simple type tracking) + llvm::Value* lastValue; // To store the result of expressions + std::string lastClassName; // To track the type of the last expression (for property access) + std::vector breakStack; // Stack of blocks to jump to on break + + llvm::Function* mainFunction = nullptr; + llvm::BasicBlock* mainInsertBlock = nullptr; + + llvm::AllocaInst* createEntryBlockAlloca(llvm::Function* theFunction, const std::string& varName, llvm::Type* type); +}; + +#endif // CODEGEN_H diff --git a/src/lexer.cpp b/src/lexer.cpp new file mode 100644 index 0000000..138aa56 --- /dev/null +++ b/src/lexer.cpp @@ -0,0 +1,216 @@ +#include "lexer.h" +#include + +Lexer::Lexer(const std::string& source) + : source(source), position(0), length(source.length()), line(1), column(1) {} + +std::vector Lexer::tokenize() { + std::vector tokens; + while (true) { + if (!tokenBuffer.empty()) { + tokens.push_back(tokenBuffer.front()); + tokenBuffer.erase(tokenBuffer.begin()); + if (tokens.back().type == TokenType::END_OF_FILE) break; + continue; + } + + Token token = scanToken(); + if (token.type != TokenType::UNKNOWN) { // Skip whitespace/comments if handled that way + tokens.push_back(token); + } + if (token.type == TokenType::END_OF_FILE) break; + } + return tokens; +} + +char Lexer::peek(int offset) const { + if (position + offset >= length) return '\0'; + return source[position + offset]; +} + +char Lexer::advance() { + char current = source[position]; + position++; + column++; + if (current == '\n') { + line++; + column = 1; + } + return current; +} + +bool Lexer::isAtEnd() const { + return position >= length; +} + +void Lexer::skipWhitespace() { + while (true) { + char c = peek(); + switch (c) { + case ' ': + case '\r': + case '\t': + advance(); + break; + case '\n': + advance(); + break; + default: + return; + } + } +} + +Token Lexer::makeToken(TokenType type, std::string value) { + return {type, value, line, column}; +} + +Token Lexer::scanToken() { + if (resumingString) { + resumingString = false; + return stringPart(); + } + + skipWhitespace(); + if (isAtEnd()) return makeToken(TokenType::END_OF_FILE, ""); + + char c = advance(); + + if (isalpha(c) || c == '_') { + // Identifier or Keyword + position--; // Backtrack to include first char + column--; + return identifierOrKeyword(); + } + + if (isdigit(c)) { + position--; + column--; + return number(); + } + + switch (c) { + case '(': return makeToken(TokenType::LPAREN, "("); + case ')': return makeToken(TokenType::RPAREN, ")"); + case '{': return makeToken(TokenType::LBRACE, "{"); + case '[': return makeToken(TokenType::LBRACKET, "["); + case ']': return makeToken(TokenType::RBRACKET, "]"); + case '}': + if (interpolationDepth > 0) { + interpolationDepth--; + resumingString = true; + tokenBuffer.push_back(makeToken(TokenType::PLUS, "+")); + return makeToken(TokenType::RPAREN, ")"); + } + return makeToken(TokenType::RBRACE, "}"); + case ',': return makeToken(TokenType::COMMA, ","); + case '.': return makeToken(TokenType::DOT, "."); + case ':': return makeToken(TokenType::COLON, ":"); + case ';': return makeToken(TokenType::SEMICOLON, ";"); + case '+': return makeToken(TokenType::PLUS, "+"); + case '-': return makeToken(TokenType::MINUS, "-"); + case '*': return makeToken(TokenType::STAR, "*"); + case '/': + if (peek() == '/') { + while (peek() != '\n' && !isAtEnd()) advance(); + return scanToken(); + } + return makeToken(TokenType::SLASH, "/"); + case '=': + if (peek() == '=') { + advance(); + return makeToken(TokenType::EQUAL_EQUAL, "=="); + } + return makeToken(TokenType::EQUALS, "="); + case '<': return makeToken(TokenType::LESS, "<"); + case '>': return makeToken(TokenType::GREATER, ">"); + case '"': return string(); + default: return makeToken(TokenType::UNKNOWN, std::string(1, c)); + } +} + +Token Lexer::identifierOrKeyword() { + size_t start = position; + while (isalnum(peek()) || peek() == '_') { + advance(); + } + + std::string text = source.substr(start, position - start); + + if (text == "function") return makeToken(TokenType::FUNCTION, text); + if (text == "return") return makeToken(TokenType::RETURN, text); + if (text == "if") return makeToken(TokenType::IF, text); + if (text == "else") return makeToken(TokenType::ELSE, text); + if (text == "while") return makeToken(TokenType::WHILE, text); + if (text == "var") return makeToken(TokenType::VAR, text); + if (text == "class") return makeToken(TokenType::CLASS, text); + if (text == "new") return makeToken(TokenType::NEW, text); + if (text == "for") return makeToken(TokenType::FOR, text); + if (text == "in") return makeToken(TokenType::IN, text); + if (text == "foreach") return makeToken(TokenType::FOREACH, text); + if (text == "switch") return makeToken(TokenType::SWITCH, text); + if (text == "case") return makeToken(TokenType::CASE, text); + if (text == "default") return makeToken(TokenType::DEFAULT, text); + if (text == "break") return makeToken(TokenType::BREAK, text); + + return makeToken(TokenType::IDENTIFIER, text); +} + +Token Lexer::number() { + size_t start = position; + while (isdigit(peek())) { + advance(); + } + return makeToken(TokenType::NUMBER, source.substr(start, position - start)); +} + +Token Lexer::string() { + std::string value = ""; + while (peek() != '"' && !isAtEnd()) { + if (peek() == '$' && peek(1) == '{') { + interpolationDepth++; + advance(); // $ + advance(); // { + tokenBuffer.push_back(makeToken(TokenType::PLUS, "+")); + tokenBuffer.push_back(makeToken(TokenType::LPAREN, "(")); + return makeToken(TokenType::STRING, value); + } + + if (peek() == '\n') line++; + value += advance(); + } + + if (isAtEnd()) { + // Error: Unterminated string. + return makeToken(TokenType::UNKNOWN, "Unterminated string"); + } + + // The closing ". + advance(); + + return makeToken(TokenType::STRING, value); +} + +Token Lexer::stringPart() { + std::string value = ""; + while (peek() != '"' && !isAtEnd()) { + if (peek() == '$' && peek(1) == '{') { + interpolationDepth++; + advance(); // $ + advance(); // { + tokenBuffer.push_back(makeToken(TokenType::PLUS, "+")); + tokenBuffer.push_back(makeToken(TokenType::LPAREN, "(")); + return makeToken(TokenType::STRING, value); + } + + if (peek() == '\n') line++; + value += advance(); + } + + if (isAtEnd()) { + return makeToken(TokenType::UNKNOWN, "Unterminated string"); + } + + advance(); // The closing " + return makeToken(TokenType::STRING, value); +} diff --git a/src/lexer.h b/src/lexer.h new file mode 100644 index 0000000..aeae32a --- /dev/null +++ b/src/lexer.h @@ -0,0 +1,36 @@ +#ifndef LEXER_H +#define LEXER_H + +#include +#include +#include "token.h" + +class Lexer { +public: + Lexer(const std::string& source); + std::vector tokenize(); + +private: + std::string source; + size_t position; + size_t length; + int line; + int column; + + std::vector tokenBuffer; + int interpolationDepth = 0; + bool resumingString = false; + + char peek(int offset = 0) const; + char advance(); + bool isAtEnd() const; + void skipWhitespace(); + Token makeToken(TokenType type, std::string value); + Token scanToken(); + Token identifierOrKeyword(); + Token number(); + Token string(); + Token stringPart(); +}; + +#endif // LEXER_H diff --git a/src/main.cpp b/src/main.cpp new file mode 100644 index 0000000..b81b27c --- /dev/null +++ b/src/main.cpp @@ -0,0 +1,169 @@ +#include +#include +#include +#include +#include "lexer.h" +#include "parser.h" +#include "ast_printer.h" +#include "codegen.h" +#include +#include + +const char* RUNTIME_CODE = R"( +#include +#include +#include + +extern "C" { + +void print_int(int x) { + printf("%d\n", x); +} + +void print_string(char* x) { + printf("%s\n", x); +} + +char* int_to_str(int x) { + char* buffer = (char*)malloc(12); // Enough for 32-bit int + snprintf(buffer, 12, "%d", x); + return buffer; +} + +char* str_concat(char* a, char* b) { + int lenA = strlen(a); + int lenB = strlen(b); + char* result = (char*)malloc(lenA + lenB + 1); + strcpy(result, a); + strcat(result, b); + return result; +} + +typedef struct { + int size; + void** data; +} Array; + +void* sun_array_create(int size) { + Array* arr = (Array*)malloc(sizeof(Array)); + arr->size = size; + arr->data = (void**)malloc(sizeof(void*) * size); + return (void*)arr; +} + +void sun_array_set(void* arrPtr, int index, void* value) { + Array* arr = (Array*)arrPtr; + if (index >= 0 && index < arr->size) { + arr->data[index] = value; + } else { + printf("Error: Array index out of bounds: %d\n", index); + exit(1); + } +} + +void* sun_array_get(void* arrPtr, int index) { + Array* arr = (Array*)arrPtr; + if (index >= 0 && index < arr->size) { + return arr->data[index]; + } else { + printf("Error: Array index out of bounds: %d\n", index); + exit(1); + } + return NULL; +} + +int sun_array_length(void* arrPtr) { + Array* arr = (Array*)arrPtr; + return arr->size; +} + +} +)"; + +int main(int argc, char* argv[]) { + if (argc < 2) { + std::cerr << "Usage: sun [-o output]" << std::endl; + return 1; + } + + std::string filename = argv[1]; + std::string outputExe = "a.out"; + bool runImmediately = true; + + for (int i = 2; i < argc; i++) { + if (std::string(argv[i]) == "-o" && i + 1 < argc) { + outputExe = argv[i+1]; + runImmediately = false; + i++; + } + } + + if (runImmediately) { + outputExe = "temp_sun_prog"; + } + + std::ifstream file(filename); + if (!file.is_open()) { + std::cerr << "Error: Could not open file " << filename << std::endl; + return 1; + } + + std::stringstream buffer; + buffer << file.rdbuf(); + std::string source = buffer.str(); + + // 1. Lexing + Lexer lexer(source); + std::vector tokens = lexer.tokenize(); + + // 2. Parsing + Parser parser(tokens); + std::vector> nodes; + while (!parser.isAtEnd()) { + nodes.push_back(parser.parseTopLevel()); + } + + // 3. Code Generation + CodeGen codegen; + for (const auto& node : nodes) { + codegen.generate(*node); + } + codegen.finish(); + + // 4. Write LLVM IR to temp file + std::error_code EC; + llvm::raw_fd_ostream dest("temp.ll", EC, llvm::sys::fs::OF_None); + if (EC) { + std::cerr << "Could not open temp.ll: " << EC.message() << std::endl; + return 1; + } + codegen.getModule()->print(dest, nullptr); + dest.flush(); + dest.close(); + + // 5. Write Runtime to temp file + std::ofstream runtimeFile("temp_runtime.cpp"); + runtimeFile << RUNTIME_CODE; + runtimeFile.close(); + + // 6. Compile and Link + std::string cmd = "clang++ -Wno-override-module temp.ll temp_runtime.cpp -o " + outputExe; + int ret = std::system(cmd.c_str()); + + // 7. Cleanup + remove("temp.ll"); + remove("temp_runtime.cpp"); + + if (ret == 0) { + if (runImmediately) { + std::string runCmd = "./" + outputExe; + int runRet = std::system(runCmd.c_str()); + remove(outputExe.c_str()); + return runRet; + } else { + std::cout << "Compilation successful. Output: " << outputExe << std::endl; + } + } + + return ret; +} diff --git a/src/parser.cpp b/src/parser.cpp new file mode 100644 index 0000000..44bfaf1 --- /dev/null +++ b/src/parser.cpp @@ -0,0 +1,465 @@ +#include "parser.h" +#include + +Parser::Parser(const std::vector& tokens) : tokens(tokens), current(0) {} + +const Token& Parser::peek() const { + return tokens[current]; +} + +const Token& Parser::previous() const { + return tokens[current - 1]; +} + +bool Parser::isAtEnd() const { + return peek().type == TokenType::END_OF_FILE; +} + +bool Parser::check(TokenType type) const { + if (isAtEnd()) return false; + return peek().type == type; +} + +const Token& Parser::advance() { + if (!isAtEnd()) current++; + return previous(); +} + +const Token& Parser::consume(TokenType type, const std::string& message) { + if (check(type)) return advance(); + std::cerr << "Error: " << message << " at line " << peek().line << std::endl; + exit(1); +} + +bool Parser::match(TokenType type) { + if (check(type)) { + advance(); + return true; + } + return false; +} + +std::unique_ptr Parser::parseFunction() { + consume(TokenType::FUNCTION, "Expect 'function' keyword."); + std::string name = consume(TokenType::IDENTIFIER, "Expect function name.").value; + + consume(TokenType::LPAREN, "Expect '(' after function name."); + std::vector args; + if (!check(TokenType::RPAREN)) { + do { + args.push_back(consume(TokenType::IDENTIFIER, "Expect parameter name.").value); + } while (match(TokenType::COMMA)); + } + consume(TokenType::RPAREN, "Expect ')' after parameters."); + + consume(TokenType::LBRACE, "Expect '{' before function body."); + auto body = block(); + + return std::make_unique(name, args, std::move(body)); +} + +std::unique_ptr Parser::block() { + auto block = std::make_unique(); + while (!check(TokenType::RBRACE) && !isAtEnd()) { + block->statements.push_back(statement()); + } + consume(TokenType::RBRACE, "Expect '}' after block."); + return block; +} + +std::unique_ptr Parser::parseClass() { + consume(TokenType::CLASS, "Expect 'class' keyword."); + std::string name = consume(TokenType::IDENTIFIER, "Expect class name.").value; + consume(TokenType::LBRACE, "Expect '{' before class body."); + + std::vector fields; + while (!check(TokenType::RBRACE) && !isAtEnd()) { + consume(TokenType::VAR, "Expect 'var' for field declaration."); + std::string fieldName = consume(TokenType::IDENTIFIER, "Expect field name.").value; + consume(TokenType::SEMICOLON, "Expect ';' after field name."); + fields.push_back(fieldName); + } + + consume(TokenType::RBRACE, "Expect '}' after class body."); + return std::make_unique(name, fields); +} + +std::unique_ptr Parser::parseTopLevel() { + if (check(TokenType::CLASS)) { + return parseClass(); + } else if (check(TokenType::FUNCTION)) { + return parseFunction(); + } else { + return statement(); + } +} + +std::unique_ptr Parser::statement() { + if (match(TokenType::RETURN)) { + return returnStatement(); + } + if (match(TokenType::IF)) { + return ifStatement(); + } + if (match(TokenType::WHILE)) { + return whileStatement(); + } + if (match(TokenType::FOR)) { + return forStatement(); + } + if (match(TokenType::FOREACH)) { + return foreachStatement(); + } + if (match(TokenType::SWITCH)) { + return switchStatement(); + } + if (match(TokenType::BREAK)) { + return breakStatement(); + } + if (match(TokenType::VAR)) { + return varDeclaration(); + } + if (match(TokenType::LBRACE)) { + return block(); + } + return expressionStatement(); +} + +std::unique_ptr Parser::returnStatement() { + std::unique_ptr value = expression(); + consume(TokenType::SEMICOLON, "Expect ';' after return value."); + return std::make_unique(std::move(value)); +} + +std::unique_ptr Parser::ifStatement() { + consume(TokenType::LPAREN, "Expect '(' after 'if'."); + std::unique_ptr condition = expression(); + consume(TokenType::RPAREN, "Expect ')' after if condition."); + + std::unique_ptr thenBranch = statement(); + std::unique_ptr elseBranch = nullptr; + + if (match(TokenType::ELSE)) { + elseBranch = statement(); + } + + return std::make_unique(std::move(condition), std::move(thenBranch), std::move(elseBranch)); +} + +std::unique_ptr Parser::whileStatement() { + consume(TokenType::LPAREN, "Expect '(' after 'while'."); + std::unique_ptr condition = expression(); + consume(TokenType::RPAREN, "Expect ')' after while condition."); + std::unique_ptr body = statement(); + + return std::make_unique(std::move(condition), std::move(body)); +} + +std::unique_ptr Parser::forStatement() { + consume(TokenType::LPAREN, "Expect '(' after 'for'."); + + // Check for 'var x in y' + if (check(TokenType::VAR)) { + // Look ahead to see if it's a for-in loop + // We need to peek 2 tokens ahead: VAR IDENTIFIER IN + // But we only have peek() and previous(). + // Let's consume VAR and IDENTIFIER, then check for IN. + // If not IN, we backtrack? No, parser is single pass. + // But standard for loop init can be 'var x = ...'. + // So: + // var x = ... -> standard + // var x in ... -> for-in + + // We can parse the var declaration part partially. + // But varDeclaration() consumes the semicolon. + + // Let's handle it manually here. + Token varToken = peek(); // VAR + advance(); // consume VAR + + std::string name = consume(TokenType::IDENTIFIER, "Expect variable name.").value; + + if (match(TokenType::IN)) { + // It is a for-in loop: for (var x in collection) + std::unique_ptr collection = expression(); + consume(TokenType::RPAREN, "Expect ')' after loop collection."); + std::unique_ptr body = statement(); + return std::make_unique(name, std::move(collection), std::move(body)); + } + + // It is a standard for loop: for (var x = 0; ...) + std::unique_ptr initializer = nullptr; + if (match(TokenType::EQUALS)) { + initializer = expression(); + } + consume(TokenType::SEMICOLON, "Expect ';' after variable declaration."); + std::unique_ptr init = std::make_unique(name, std::move(initializer)); + + // Continue parsing standard for loop + std::unique_ptr condition = nullptr; + if (!check(TokenType::SEMICOLON)) { + condition = expression(); + } + consume(TokenType::SEMICOLON, "Expect ';' after loop condition."); + + std::unique_ptr increment = nullptr; + if (!check(TokenType::RPAREN)) { + increment = expression(); + } + consume(TokenType::RPAREN, "Expect ')' after for clauses."); + + std::unique_ptr body = statement(); + return std::make_unique(std::move(init), std::move(condition), std::move(increment), std::move(body)); + } + + std::unique_ptr init = nullptr; + if (match(TokenType::SEMICOLON)) { + init = nullptr; + } else { + init = expressionStatement(); + } + + std::unique_ptr condition = nullptr; + if (!check(TokenType::SEMICOLON)) { + condition = expression(); + } + consume(TokenType::SEMICOLON, "Expect ';' after loop condition."); + + std::unique_ptr increment = nullptr; + if (!check(TokenType::RPAREN)) { + increment = expression(); + } + consume(TokenType::RPAREN, "Expect ')' after for clauses."); + + std::unique_ptr body = statement(); + + return std::make_unique(std::move(init), std::move(condition), std::move(increment), std::move(body)); +} + +std::unique_ptr Parser::foreachStatement() { + consume(TokenType::LPAREN, "Expect '(' after 'foreach'."); + + // foreach (var x in y) or foreach (x in y) + // Let's enforce 'var' for now or optional? + // User said "foreach (item in array)". + + std::string name; + if (match(TokenType::VAR)) { + name = consume(TokenType::IDENTIFIER, "Expect variable name.").value; + } else { + name = consume(TokenType::IDENTIFIER, "Expect variable name.").value; + } + + consume(TokenType::IN, "Expect 'in' after variable name."); + std::unique_ptr collection = expression(); + consume(TokenType::RPAREN, "Expect ')' after loop collection."); + + std::unique_ptr body = statement(); + return std::make_unique(name, std::move(collection), std::move(body)); +} + +std::unique_ptr Parser::switchStatement() { + consume(TokenType::LPAREN, "Expect '(' after 'switch'."); + std::unique_ptr condition = expression(); + consume(TokenType::RPAREN, "Expect ')' after switch condition."); + consume(TokenType::LBRACE, "Expect '{' before switch body."); + + std::vector cases; + std::unique_ptr defaultCase = nullptr; + + while (!check(TokenType::RBRACE) && !isAtEnd()) { + if (match(TokenType::CASE)) { + std::unique_ptr value = expression(); + consume(TokenType::COLON, "Expect ':' after case value."); + + // Parse statements until next case/default/end + std::vector> stmts; + while (!check(TokenType::CASE) && !check(TokenType::DEFAULT) && !check(TokenType::RBRACE) && !isAtEnd()) { + stmts.push_back(statement()); + } + + auto block = std::make_unique(); + block->statements = std::move(stmts); + cases.push_back({std::move(value), std::move(block)}); + } else if (match(TokenType::DEFAULT)) { + consume(TokenType::COLON, "Expect ':' after default."); + + std::vector> stmts; + while (!check(TokenType::CASE) && !check(TokenType::DEFAULT) && !check(TokenType::RBRACE) && !isAtEnd()) { + stmts.push_back(statement()); + } + + auto block = std::make_unique(); + block->statements = std::move(stmts); + defaultCase = std::move(block); + } else { + // Error or unexpected token + std::cerr << "Expect 'case' or 'default' inside switch." << std::endl; + exit(1); + } + } + + consume(TokenType::RBRACE, "Expect '}' after switch body."); + return std::make_unique(std::move(condition), std::move(cases), std::move(defaultCase)); +} + +std::unique_ptr Parser::breakStatement() { + consume(TokenType::SEMICOLON, "Expect ';' after 'break'."); + return std::make_unique(); +} + +std::unique_ptr Parser::varDeclaration() { + std::string name = consume(TokenType::IDENTIFIER, "Expect variable name.").value; + std::unique_ptr initializer = nullptr; + if (match(TokenType::EQUALS)) { + initializer = expression(); + } + consume(TokenType::SEMICOLON, "Expect ';' after variable declaration."); + return std::make_unique(name, std::move(initializer)); +} + +std::unique_ptr Parser::expressionStatement() { + std::unique_ptr expr = expression(); + consume(TokenType::SEMICOLON, "Expect ';' after expression."); + return std::make_unique(std::move(expr)); +} + +std::unique_ptr Parser::expression() { + return assignment(); +} + +std::unique_ptr Parser::assignment() { + std::unique_ptr expr = equality(); + + if (match(TokenType::EQUALS)) { + Token equals = previous(); + std::unique_ptr value = assignment(); + + if (auto varExpr = dynamic_cast(expr.get())) { + std::string name = varExpr->name; + return std::make_unique(name, std::move(value)); + } else if (auto getProp = dynamic_cast(expr.get())) { + return std::make_unique(std::move(getProp->object), getProp->name, std::move(value)); + } else if (auto indexExpr = dynamic_cast(expr.get())) { + return std::make_unique(std::move(indexExpr->array), std::move(indexExpr->index), std::move(value)); + } + + std::cerr << "Error: Invalid assignment target." << std::endl; + exit(1); + } + + return expr; +} + +std::unique_ptr Parser::equality() { + std::unique_ptr expr = comparison(); + + while (match(TokenType::EQUAL_EQUAL)) { // Add != later + std::string op = previous().value; + std::unique_ptr right = comparison(); + expr = std::make_unique(std::move(expr), op, std::move(right)); + } + + return expr; +} + +std::unique_ptr Parser::comparison() { + std::unique_ptr expr = term(); + + while (match(TokenType::LESS) || match(TokenType::GREATER)) { + std::string op = previous().value; + std::unique_ptr right = term(); + expr = std::make_unique(std::move(expr), op, std::move(right)); + } + + return expr; +} + +std::unique_ptr Parser::term() { + std::unique_ptr expr = factor(); + + while (match(TokenType::PLUS) || match(TokenType::MINUS)) { + std::string op = previous().value; + std::unique_ptr right = factor(); + expr = std::make_unique(std::move(expr), op, std::move(right)); + } + + return expr; +} + +std::unique_ptr Parser::factor() { + std::unique_ptr expr = call(); + + while (match(TokenType::STAR) || match(TokenType::SLASH)) { + std::string op = previous().value; + std::unique_ptr right = call(); + expr = std::make_unique(std::move(expr), op, std::move(right)); + } + + return expr; +} + +std::unique_ptr Parser::call() { + std::unique_ptr expr = primary(); + + while (true) { + if (match(TokenType::DOT)) { + std::string name = consume(TokenType::IDENTIFIER, "Expect property name after '.'.").value; + expr = std::make_unique(std::move(expr), name); + } else if (match(TokenType::LBRACKET)) { + std::unique_ptr index = expression(); + consume(TokenType::RBRACKET, "Expect ']' after index."); + expr = std::make_unique(std::move(expr), std::move(index)); + } else { + break; + } + } + + return expr; +} + +std::unique_ptr Parser::primary() { + if (match(TokenType::NEW)) { + std::string className = consume(TokenType::IDENTIFIER, "Expect class name.").value; + consume(TokenType::LPAREN, "Expect '(' after class name."); + consume(TokenType::RPAREN, "Expect ')' after class name."); + return std::make_unique(className); + } + if (match(TokenType::NUMBER)) { + return std::make_unique(std::stoi(previous().value)); + } + if (match(TokenType::STRING)) { + return std::make_unique(previous().value); + } + if (match(TokenType::IDENTIFIER)) { + std::string name = previous().value; + if (match(TokenType::LPAREN)) { + std::vector> args; + if (!check(TokenType::RPAREN)) { + do { + args.push_back(expression()); + } while (match(TokenType::COMMA)); + } + consume(TokenType::RPAREN, "Expect ')' after arguments."); + return std::make_unique(name, std::move(args)); + } + return std::make_unique(name); + } + if (match(TokenType::LPAREN)) { + std::unique_ptr expr = expression(); + consume(TokenType::RPAREN, "Expect ')' after expression."); + return expr; + } + if (match(TokenType::LBRACKET)) { + std::vector> elements; + if (!check(TokenType::RBRACKET)) { + do { + elements.push_back(expression()); + } while (match(TokenType::COMMA)); + } + consume(TokenType::RBRACKET, "Expect ']' after array elements."); + return std::make_unique(std::move(elements)); + } + std::cerr << "Error: Expect expression at line " << peek().line << std::endl; + exit(1); +} diff --git a/src/parser.h b/src/parser.h new file mode 100644 index 0000000..3adbf50 --- /dev/null +++ b/src/parser.h @@ -0,0 +1,50 @@ +#ifndef PARSER_H +#define PARSER_H + +#include +#include +#include "token.h" +#include "ast.h" + +class Parser { +public: + Parser(const std::vector& tokens); + std::unique_ptr parseFunction(); + std::unique_ptr parseTopLevel(); + bool isAtEnd() const; + +private: + const std::vector& tokens; + size_t current; + + const Token& peek() const; + const Token& previous() const; + bool check(TokenType type) const; + const Token& advance(); + const Token& consume(TokenType type, const std::string& message); + bool match(TokenType type); + + std::unique_ptr expression(); + std::unique_ptr assignment(); + std::unique_ptr equality(); + std::unique_ptr comparison(); + std::unique_ptr term(); // + - + std::unique_ptr factor(); // * / + std::unique_ptr call(); // . () + std::unique_ptr primary(); + + std::unique_ptr statement(); + std::unique_ptr returnStatement(); + std::unique_ptr ifStatement(); + std::unique_ptr whileStatement(); + std::unique_ptr forStatement(); + std::unique_ptr foreachStatement(); + std::unique_ptr switchStatement(); + std::unique_ptr breakStatement(); + std::unique_ptr varDeclaration(); + std::unique_ptr expressionStatement(); + std::unique_ptr block(); + std::unique_ptr parseClass(); +}; + +#endif // PARSER_H diff --git a/src/runtime.cpp b/src/runtime.cpp new file mode 100644 index 0000000..1f7a817 --- /dev/null +++ b/src/runtime.cpp @@ -0,0 +1,68 @@ +#include +#include +#include + +extern "C" { + +void print_int(int x) { + printf("%d\n", x); +} + +void print_string(char* x) { + printf("%s\n", x); +} + +char* int_to_str(int x) { + char* buffer = (char*)malloc(12); // Enough for 32-bit int + snprintf(buffer, 12, "%d", x); + return buffer; +} + +char* str_concat(char* a, char* b) { + int lenA = strlen(a); + int lenB = strlen(b); + char* result = (char*)malloc(lenA + lenB + 1); + strcpy(result, a); + strcat(result, b); + return result; +} + +typedef struct { + int size; + void** data; +} Array; + +void* sun_array_create(int size) { + Array* arr = (Array*)malloc(sizeof(Array)); + arr->size = size; + arr->data = (void**)malloc(sizeof(void*) * size); + return (void*)arr; +} + +void sun_array_set(void* arrPtr, int index, void* value) { + Array* arr = (Array*)arrPtr; + if (index >= 0 && index < arr->size) { + arr->data[index] = value; + } else { + printf("Error: Array index out of bounds: %d\n", index); + exit(1); + } +} + +void* sun_array_get(void* arrPtr, int index) { + Array* arr = (Array*)arrPtr; + if (index >= 0 && index < arr->size) { + return arr->data[index]; + } else { + printf("Error: Array index out of bounds: %d\n", index); + exit(1); + } + return NULL; +} + +int sun_array_length(void* arrPtr) { + Array* arr = (Array*)arrPtr; + return arr->size; +} + +} diff --git a/src/token.h b/src/token.h new file mode 100644 index 0000000..d869176 --- /dev/null +++ b/src/token.h @@ -0,0 +1,91 @@ +#ifndef TOKEN_H +#define TOKEN_H + +#include +#include + +enum class TokenType { + FUNCTION, + RETURN, + IF, + ELSE, + WHILE, + VAR, + CLASS, + NEW, + STRING, // "string" + FOR, // for + IN, // in + FOREACH, // foreach + SWITCH, // switch + CASE, // case + DEFAULT, // default + BREAK, // break + IDENTIFIER, + NUMBER, + LPAREN, // ( + RPAREN, // ) + LBRACE, // { + RBRACE, // } + LBRACKET, // [ + RBRACKET, // ] + COMMA, // , + DOT, // . + COLON, // : + SEMICOLON, // ; + PLUS, // + + MINUS, // - + STAR, // * + SLASH, // / + EQUALS, // = + EQUAL_EQUAL,// == + LESS, // < + GREATER, // > + END_OF_FILE, + UNKNOWN +}; + +struct Token { + TokenType type; + std::string value; + int line; + int column; + + std::string toString() const { + switch (type) { + case TokenType::FUNCTION: return "FUNCTION"; + case TokenType::RETURN: return "RETURN"; + case TokenType::IF: return "IF"; + case TokenType::ELSE: return "ELSE"; + case TokenType::WHILE: return "WHILE"; + case TokenType::VAR: return "VAR"; + case TokenType::CLASS: return "CLASS"; + case TokenType::NEW: return "NEW"; + case TokenType::STRING: return "STRING(" + value + ")"; + case TokenType::FOR: return "FOR"; + case TokenType::IN: return "IN"; + case TokenType::FOREACH: return "FOREACH"; + case TokenType::IDENTIFIER: return "IDENTIFIER(" + value + ")"; + case TokenType::NUMBER: return "NUMBER(" + value + ")"; + case TokenType::LPAREN: return "LPAREN"; + case TokenType::RPAREN: return "RPAREN"; + case TokenType::LBRACE: return "LBRACE"; + case TokenType::RBRACE: return "RBRACE"; + case TokenType::COMMA: return "COMMA"; + case TokenType::DOT: return "DOT"; + case TokenType::SEMICOLON: return "SEMICOLON"; + case TokenType::PLUS: return "PLUS"; + case TokenType::MINUS: return "MINUS"; + case TokenType::STAR: return "STAR"; + case TokenType::SLASH: return "SLASH"; + case TokenType::EQUALS: return "EQUALS"; + case TokenType::EQUAL_EQUAL: return "EQUAL_EQUAL"; + case TokenType::LESS: return "LESS"; + case TokenType::GREATER: return "GREATER"; + case TokenType::END_OF_FILE: return "EOF"; + default: return "UNKNOWN(" + value + ")"; + } + } +}; + +#endif // TOKEN_H diff --git a/tests/test.sun b/tests/test.sun new file mode 100644 index 0000000..f4411d0 --- /dev/null +++ b/tests/test.sun @@ -0,0 +1,5 @@ +function soma(a, b) { + return a + b; +} + +print(soma(2, 3)); \ No newline at end of file diff --git a/tests/test_array.sun b/tests/test_array.sun new file mode 100644 index 0000000..c3e51de --- /dev/null +++ b/tests/test_array.sun @@ -0,0 +1,9 @@ +var arr = [10, 20, 30]; +print(arr[0]); +print(arr[1]); +arr[2] = 50; +print(arr[2]); + +var strArr = ["Hello", "World"]; +print(strArr[0]); +print(strArr[1]); diff --git a/tests/test_class.sun b/tests/test_class.sun new file mode 100644 index 0000000..6c6b7bb --- /dev/null +++ b/tests/test_class.sun @@ -0,0 +1,11 @@ +class Point { + var x; + var y; +} + +function main() { + var p = new Point(); + p.x = 10; + p.y = 20; + return p.x + p.y; +} \ No newline at end of file diff --git a/tests/test_for.sun b/tests/test_for.sun new file mode 100644 index 0000000..7f57a4a --- /dev/null +++ b/tests/test_for.sun @@ -0,0 +1,7 @@ +function main() { + print("Counting:"); + for (var i = 0; i < 5; i = i + 1) { + print(i); + } + return 0; +} \ No newline at end of file diff --git a/tests/test_if.sun b/tests/test_if.sun new file mode 100644 index 0000000..b5cbace --- /dev/null +++ b/tests/test_if.sun @@ -0,0 +1,9 @@ +function test(a) { + var x = 10; + if (a > x) { + return 1; + } else { + x = x + 1; + return x; + } +} diff --git a/tests/test_interpolation.sun b/tests/test_interpolation.sun new file mode 100644 index 0000000..a7c197f --- /dev/null +++ b/tests/test_interpolation.sun @@ -0,0 +1,4 @@ +print("Hello " + "World"); +var name = "Sun"; +print("Hello ${name}!"); +print("Value: ${10 + 20}"); diff --git a/tests/test_loops.sun b/tests/test_loops.sun new file mode 100644 index 0000000..549b120 --- /dev/null +++ b/tests/test_loops.sun @@ -0,0 +1,16 @@ +var arr = [10, 20, 30]; +print("For loop:"); +for (var i = 0; i < 3; i = i + 1) { + print(arr[i]); +} + +print("For-in loop:"); +for (var item in arr) { + print(item); +} + +var strArr = ["Hello", "World"]; +print("Foreach loop:"); +foreach (s in strArr) { + print(s); +} diff --git a/tests/test_print.sun b/tests/test_print.sun new file mode 100644 index 0000000..5270f2e --- /dev/null +++ b/tests/test_print.sun @@ -0,0 +1,6 @@ +function main() { + var name = "World"; + print("Hello " + name); + print(123); + return 0; +} \ No newline at end of file diff --git a/tests/test_switch.sun b/tests/test_switch.sun new file mode 100644 index 0000000..65e74fb --- /dev/null +++ b/tests/test_switch.sun @@ -0,0 +1,31 @@ +var x = 2; +var result = 0; + +switch (x) { + case 1: { + result = 10; + break; + } + case 2: { + result = 20; + // Fallthrough to 3 + } + case 3: { + result = result + 5; + break; + } + default: { + result = 0 - 1; + } +} + +print(result); + +var i = 0; +while (i < 10) { + if (i == 5) { + break; + } + i = i + 1; +} +print(i); diff --git a/tests/test_while.sun b/tests/test_while.sun new file mode 100644 index 0000000..c9141dd --- /dev/null +++ b/tests/test_while.sun @@ -0,0 +1,7 @@ +function loop(max) { + var i = 0; + while (i < max) { + i = i + 1; + } + return i; +}