Initial commit - copied from GitLab

This commit is contained in:
os222 2025-05-24 22:13:16 +01:00
commit 4d6426c899
440 changed files with 15150 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
.metals/
.bsp/
.scala-build/
.vscode/

16
.gitlab-ci.yml Normal file
View File

@ -0,0 +1,16 @@
stages:
- compile
- test
compile:
stage: compile
script:
- scala .
test:
stage: test
script:
- scala test . # This will test everything. Will need to be modified as we add tests
# More stages will be added later to independently test different phases
# To allow chosen tests to run as we add functionality

21
Makefile Normal file
View File

@ -0,0 +1,21 @@
# NOTE: PLEASE DON'T USE THIS MAKEFILE, IT IS FOR LABTS
# it is *much* more efficient to use `scala compile .` trust me, I'm watching you.
all:
# the --server=false flag helps improve performance on LabTS by avoiding
# downloading the build-server "bloop".
# the --jvm system flag helps improve performance on LabTS by preventing
# scala-cli from downloading a whole jdk distribution on the lab machine
# the --force flag ensures that any existing built compiler is overwritten
# the --power flag is needed as `package` is an experimental "power user" feature (NOTE: use this or --assembly if anything goes wrong)
# scala --power package . --server=false --jvm system --force -o wacc-compiler
# you can use --assembly to make it build a self-contained jar,
# scala --power package . --server=false --jvm system --assembly --force -o wacc-compiler
# you can use --native to make it build a native application (requiring Scala Native),
# scala --power package . --server=false --jvm system --native --force -o wacc-compiler
# or you can use --graalvm-jvm-id graalvm-java21 --native-image to build it using graalvm
scala --power package . --server=false --jvm system --graalvm-jvm-id graalvm-java21 --native-image --force -o wacc-compiler
clean:
scala clean . && rm -f wacc-compiler
.PHONY: all clean

35
README.md Normal file
View File

@ -0,0 +1,35 @@
This is the provided git repository for the WACC compilers lab. You should work
in this repository regularly committing and pushing your work back to GitLab.
# Provided files/directories
## src/main
The src/main directory is where the code for your compiler should go, and just
contains a stub hello world file with a simple calculator inside.
## src/test
The src/test directory is where you should put the code for your tests, which
can be run via `scala-cli test .`. The suggested framework is `scalatest`, the dependency
for which has already been included.
## project.scala
The `project.scala` is the definition of your project's build requirements. By default,
this skeleton has added the latest stable versions of both `scalatest` and `parsley`
to the build: you should check **regularly** to see if your `parsley` needs updating
during the course of WACC!
## compile
The compile script can be edited to change the frontend interface to your WACC
compiler. You are free to change the language used in this script, but do not
change its name.
## Makefile
Your Makefile should be edited so that running 'make' in the root directory
builds your WACC compiler. Currently running 'make' will call
`scala --power package . --server=false --jvm system --graalvm-jvm-id graalvm-java21 --native-image --force -o wacc-compiler`, producing a file called
`wacc-compiler`
in the root directory of the project. If this doesn't work for whatever reason, there are a few
different alternatives you can try in the makefile. **Do not use the makefile as you're working, it's for labts/CI!**

9
compile Executable file
View File

@ -0,0 +1,9 @@
#!/bin/bash
# Bash front-end for the WACC compiler.
# The language used for this script may be changed, but its name must *not* be.
# Adjust the invocation below to suit the specific internal flags of your compiler.

# Forward every argument untouched and propagate the compiler's exit status.
./wacc-compiler "$@"
status=$?
exit "$status"

26
project.scala Normal file
View File

@ -0,0 +1,26 @@
//> using scala 3.6
//> using platform jvm
// dependencies
//> using dep com.github.j-mie6::parsley::5.0.0-M12
//> using dep com.github.scopt::scopt::4.1.0
//> using dep com.lihaoyi::os-lib::0.11.4
//> using test.dep org.scalatest::scalatest::3.2.19
// these are all sensible defaults to catch annoying issues
//> using options -deprecation -unchecked -feature
//> using options -Wimplausible-patterns -Wunused:all
//> using options -Yexplicit-nulls -Wsafe-init -Xkind-projector:underscores
// these will help ensure you have access to the latest parsley releases
// even before they land on maven proper, or snapshot versions, if necessary.
// just in case they cause problems, however, keep them turned off unless you
// specifically need them.
// using repositories sonatype-s01:releases
// using repositories sonatype-s01:snapshots
// these are flags used by Scala native: if you aren't using scala-native, then they do nothing
// lto-thin has decent linking times, and release-fast does not do too much optimisation.
// using nativeLto thin
// using nativeGc commix
// using nativeMode release-fast

109
src/main/wacc/Main.scala Normal file
View File

@ -0,0 +1,109 @@
package wacc.frontend
import parsley._
import scopt.OParser
import wacc.backend.CodeGenerator
import wacc.frontend.semantic._
import wacc.frontend.semantic.environment._
import wacc.frontend.syntax.ast.WProgram
import wacc.frontend.syntax.ImportHandler
import wacc.frontend.syntax.parser
import wacc.frontend.syntax.SYNTAX_ERROR_CODE
import wacc.extension.Peephole
/** scopt builder specialised to [[Config]]; it supplies the `opt` DSL used by `oParser`. */
val builder = OParser.builder[Config]

/**
 * Command-line flag parser.
 *
 * Each option folds its effect into a copy of the running [[Config]]:
 *  - `-p` / `--peephole` sets `peephole`
 *  - `-i` / `--imports` sets `imports`
 *  - `-o` / `--output <value>` records the value in `output`
 */
val oParser = {
  import builder._

  val peepholeOpt = opt[Unit]('p', "peephole")
    .action { (_, cfg) => cfg.copy(peephole = true) }
    .text("Enable peephole mode")

  val importsOpt = opt[Unit]('i', "imports")
    .action { (_, cfg) => cfg.copy(imports = true) }
    .text("Enable imports mode")

  val outputOpt = opt[String]('o', "output")
    .action { (target, cfg) => cfg.copy(output = Some(target)) }
    .text("Output file")

  OParser.sequence(peepholeOpt, importsOpt, outputOpt)
}
/**
 * Command-line configuration accumulated by `oParser`.
 *
 * @param peephole true when `-p` / `--peephole` was passed
 * @param imports  true when `-i` / `--imports` was passed
 * @param output   value given to `-o` / `--output`, if any
 * @param input    not populated by any option in `oParser`; `main` reads the
 *                 input path from the first argument instead
 */
case class Config(peephole: Boolean = false,
imports: Boolean = false,
output: Option[String] = None,
input: Option[String] = None)
// Example script: scala . -- "src/test/wacc/waccPrograms/valid/advanced/hashTable.wacc"
/**
 * Program entry point.
 *
 * The first argument is the path of the WACC source file; any remaining
 * arguments are parsed as flags by `oParser`. When no flags are given, or the
 * flags fail to parse, the default [[Config]] is used.
 *
 * @param args raw command-line arguments: `<input file> [flags...]`
 */
def main(args: Array[String]): Unit = {
  // Flags follow the input path, so only the tail is handed to scopt.
  // Parsing happens before the greeting so that any scopt diagnostics keep
  // their original position in the output.
  val config =
    if (args.length < 2) Config()
    else OParser.parse(oParser, args.tail, Config()).getOrElse(Config())

  println("Hello WACC")

  // Match on the Option directly instead of smuggling "missing" through a
  // sentinel string, which would misfire for a file literally named "Not Found".
  args.headOption match {
    case Some(inputFile) =>
      compile(inputFile, config.peephole, config.imports, config.output)
    case None =>
      println("Please enter a valid file path")
  }
}
/**
 * Compiles one WACC source file to an assembly file.
 * THIS IS THE FUNCTION TO CALL FOR LABTS
 *
 * The file is parsed, imports are resolved, the program is analysed and, on
 * success, code is generated (optionally peephole-optimised) and written to
 * `<output dir>/<input name with ".wacc" replaced by ".s">`. The full program
 * is not printed to the terminal.
 *
 * Exits the process with `SYNTAX_ERROR_CODE` on a parse failure and with
 * `SEMANTIC_ERROR_CODE` when analysis reports errors.
 *
 * @param pathStr      path of the WACC source file
 * @param peepholeFlag when true, the generated instructions are passed through `Peephole`
 * @param importsFlag  forwarded to `ImportHandler.addImports` as `optimiseFlag`
 * @param outputFlag   optional output directory; a relative directory is resolved
 *                     against the working directory, an absolute one is used as is.
 *                     Defaults to the working directory.
 */
def compile(pathStr: String, peepholeFlag: Boolean, importsFlag: Boolean, outputFlag: Option[String]): Unit = {
  parser.parseFile(pathStr) match {
    case parsley.Failure(msg) => {
      println(s"\nSyntax Error $msg")
      sys.exit(SYNTAX_ERROR_CODE)
    }
    case parsley.Success(progFromSyntax: WProgram) => {
      val fileName = os.FilePath(pathStr).last
      val progToAnalyse = ImportHandler.apply().addImports(progFromSyntax, optimiseFlag = importsFlag)
      // NOTE: `analyse` uses Left for success and Right for the error list.
      analyse(progToAnalyse) match {
        case Left(res) => {
          println("Completed Frontend ... Transitioning to backend...")
          val (prog, fEnv, mEnv) = res
          val codeGen = new CodeGenerator().generate(prog, fEnv, mEnv)
          val optimisation = if (peepholeFlag) Peephole.apply().optimize(codeGen) else codeGen
          val asmCode = optimisation.map(_.emit).mkString("\n")
          // Swap only a trailing ".wacc" for ".s". A blanket replace would also
          // rewrite ".wacc" in the middle of a name and would leave an input
          // without the extension unrenamed, so the output could overwrite it.
          val outputFile = fileName.stripSuffix(".wacc") + ".s"
          // os.Path(_, base) accepts both absolute and relative directories
          // (including ".."), whereas os.SubPath throws on either of those.
          val outputDir = outputFlag.fold(os.pwd)(dir => os.Path(dir, os.pwd))
          val outputPath = outputDir / outputFile
          os.makeDir.all(outputDir)
          os.write.over(outputPath, asmCode)
          println(s"Assembled code generated at ${outputDir}")
        }
        case Right(errs) => {
          errs.foreach(err => println(s"\n${err.getMessage(fileName, pathStr)}"))
          sys.exit(SEMANTIC_ERROR_CODE)
        }
      }
    }
  }
}

View File

@ -0,0 +1,754 @@
package wacc.backend
import wacc.frontend.syntax.ast._
import wacc.frontend.semantic.environment._
import scala.collection.mutable
class CodeGenerator {
private given allocator: RegisterAllocator = new RegisterAllocator()
private given labelGenerator: LabelGenerator = new LabelGenerator()
/**
 * Main entry for code generation.
 *
 * Generation order matters: the functions are generated first, then the main
 * body, and only then the preamble, which reads the string literals recorded
 * by the label generator. The emitted order is: preamble, main, functions,
 * standard-label widgets, and a final EOF marker.
 *
 * @return the complete program as a list of instructions
 */
def generate(program: WProgram, globalFuncEnv: GlobalFuncEnv, globalMainEnv: GlobalMainEnv):
    List[Instruction] = {
  given gMainEnv: GlobalMainEnv = globalMainEnv
  given gFuncEnv: GlobalFuncEnv = globalFuncEnv

  // Code for every function declaration, in declaration order
  val funcCode = mutable.ListBuffer.empty[Instruction]
  for (func <- program.funcs) funcCode ++= genFunc(func)

  // Main body, generated after the functions but emitted before them
  val mainCode = genMain(program.stats)

  // String metadata header, built once all generation above has run
  val header = preamble()

  // Any standard labels found during generation of the program
  val widgets = labelGenerator.genWidgets()

  // EOF adds a new line to signify the end of the file
  (header ++ mainCode ++ funcCode ++ widgets) :+ EOF()
}
/**
 * Builds the preamble of the assembly file: the `.data` section holding every
 * string literal known to the label generator, followed by the directives
 * that open the text section.
 *
 * @return the preamble as a list of instructions
 */
def preamble(): List[Instruction] = {
  // Opens the data section
  val dataHeader: Instruction = Directive("data")

  // For each literal: a length comment, the length word, its label and its contents
  val literalEntries: List[Instruction] =
    labelGenerator.getStringLiterals().toList.flatMap { case (str, (newStr, label)) =>
      List[Instruction](
        Comment(s"length of $label", ""),
        Directive(s"word ${str.size}", " "),
        Label(label),
        Directive(s"asciz \"$newStr\"", " ")
      )
    }

  // Remaining preamble info
  val textHeader = List[Instruction](
    Directive("align 4"),
    Directive("text"),
    Directive("global main")
  )

  dataHeader :: (literalEntries ++ textHeader)
}
/**
 * Builds a function prologue: the function label, the save of the frame
 * pointer and link register, the frame-pointer setup, the pushes of the given
 * callee registers, and the reservation of stack space for any variables
 * beyond `MAX_CALLEE_REGS`.
 *
 * @param funcLabel       label emitted at the start of the function
 * @param _pushCalleeList callee registers to push, paired with their identifiers
 * @return the prologue as a list of instructions
 */
def funcPrologue(funcLabel: String, _pushCalleeList: List[(String, PhysicalReg)]): List[Instruction] = {
  val frameSetup = List[Instruction](
    Label(funcLabel),
    // Save the frame pointer and link register
    Comment("push {fp, lr}"),
    STP(FP, LR, SP),
    // Now that fp has been saved it can track the current stack pointer
    Comment("Set fp to sp"),
    MoveOps(FP, SP)
  )

  // Registers are stored with STP in pairs; an odd register out is paired with the zero register
  val calleeSaves = generatePushInstructions(_pushCalleeList, callingFunc = false)

  // Space for the remaining local variables, rounded up to an even slot count (16-byte alignment)
  val localsBytes = ((allocator.getNumVariables() - MAX_CALLEE_REGS + 1) & ~1) * EIGHT_BYTES
  val reserveLocals: List[Instruction] =
    if (localsBytes > 0) List(SUB(SP, SP, Immediate(localsBytes))) else Nil

  frameSetup ++ calleeSaves ++ reserveLocals
}
/**
 * Builds a function epilogue: releases the stack space reserved for local
 * variables, pops the given callee registers, restores the frame pointer and
 * link register, and returns.
 *
 * @param _popCalleeList callee registers to pop, paired with their identifiers
 * @return the epilogue as a list of instructions
 */
def funcEpilogue(_popCalleeList: List[(String, PhysicalReg)]): List[Instruction] = {
  // Undo the displacement applied to sp by the prologue, if there was one
  val localsBytes = ((allocator.getNumVariables() - MAX_CALLEE_REGS + 1) & ~1) * EIGHT_BYTES
  val releaseLocals: List[Instruction] =
    if (localsBytes > 0) List(ADD(SP, SP, Immediate(localsBytes))) else Nil

  // Restore (and free) the callee registers
  val calleeRestores = generatePopInstructions(_popCalleeList, callingFunc = false)

  // Restore the frame pointer and link register, then return
  val frameTeardown = List[Instruction](
    Comment("pop {fp, lr}"),
    LDP(FP, LR, SP),
    RET(),
    EOF()
  )

  releaseLocals ++ calleeRestores ++ frameTeardown
}
/*
Generates the instructions that save the allocations in _pushList on the stack.

Register entries are converted to X registers and stored pairwise with STP; an odd
trailing register is stored together with XZR.

When callingFunc is true (and at least one register entry exists) this additionally:
- reserves stack space for the Spill entries and copies each one into it via X17,
- remaps every identifier of the pushed list in the allocator to a Spill relative to X16,
- emits a move of SP into X16.

NOTE(review): the spill handling is skipped when _pushList holds no register entries,
even if it holds Spill entries - confirm this is intended.

Returns a list of the instructions
*/
def generatePushInstructions(_pushList: List[(String, Allocated)], callingFunc: Boolean):
List[Instruction] = {
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
// Transform the _pushList to only contain x registers (Spill entries are dropped here)
val pushList = _pushList.collect {
case (id, reg: Reg) => (id, Register(reg.regNumber, X_REG))
}
if (!pushList.isEmpty) {
// The comment lists every entry of _pushList, including the ones that are not registers
instructions += Comment(s"push {${_pushList.map(_._2.toOperand).mkString(", ")}}")
}
pushList.sliding(2, 2).foreach { pair =>
// If there is a pair of registers, then push them together, otherwise, push the singleton
if (pair.length == 2) {
instructions += STP(pair(0)._2, pair(1)._2, SP)
} else {
instructions += STP(pair(0)._2, XZR, SP)
}
}
if (!pushList.isEmpty && callingFunc) {
// Here we add the stack allocations in a fashion similar to stp for registers so that we can
// change stack allocation easier later on
val spillPushList = _pushList.collect { case (id, spill: Spill) => (id, spill)}
// Calculate the amount of stack space to allocate to the stack
// (the number of spills rounded up to an even count, times 8 bytes)
val _addedStackSpace = ((spillPushList.size + 1) & ~1) * EIGHT_BYTES
val addedStackSpace = Immediate(_addedStackSpace)
if (_addedStackSpace != 0) {
// Subtract that space from the stack pointer
instructions += SUB(SP, SP, addedStackSpace)
}
// Push each memory location onto the stack
spillPushList.sliding(2, 2).zipWithIndex.foreach { case (pair, index) =>
// Index is incremented by 1
val adjustedIndex = index + 1
// If there is a pair of spills, then push the rhs first and then the lhs
if (pair.length == 2) {
// Store the two allocated values
instructions ++= pushMemToStack(pair(1)._2, 2 * adjustedIndex - 1, _addedStackSpace)
instructions ++= pushMemToStack(pair(0)._2, 2 * adjustedIndex, _addedStackSpace)
// Otherwise, push the zero register and then the other register
} else {
// The offset to save memory to
val saveOffset = Immediate(_addedStackSpace - (2 * adjustedIndex - 1) * EIGHT_BYTES)
// Store the two allocated values
instructions += Store(XZR, SP, Some(saveOffset))
instructions ++= pushMemToStack(pair(0)._2, (2 * adjustedIndex), _addedStackSpace)
}
}
// Order the registers in the same order of how they'd appear in the stack
val stackList = (toStackList(_pushList)).reverse
instructions += Comment(s"Our stack has the order of (front of list shows what #0 is): [${stackList.map(_._2.toOperand).mkString(", ")}]")
// We change the identifier map to point to the stack position of the stored values if we are calling a function
stackList.zipWithIndex.foreach { case ((id, alloc), index) =>
// We create an offset equal to the index of the register times 8 Bytes
val offsetImm = Immediate(EIGHT_BYTES * index)
val stackLocation = Spill(offsetImm, X16)
allocator.addAllocMapping(id, stackLocation)
}
instructions += Comment("Set up X16 as a temporary second base pointer for the caller saved things")
instructions += MoveOps(X16, SP)
}
instructions.toList
}
/*
Generates the instructions that restore the allocations in _popList from the stack.

Callers pass the list in the reverse of the order used for the matching push. Register
entries are converted to X registers and loaded pairwise with LDP; when the count is odd,
the first entry is loaded on its own together with XZR. Every register entry is then
released through allocator.freeCallee.

When callingFunc is true this additionally:
- first releases the stack space taken by Spill entries (only if a register entry exists),
- finally re-registers every (id, allocation) of _popList in the allocator, undoing the
remapping performed by generatePushInstructions.

Returns a list of the instructions
*/
def generatePopInstructions(_popList: List[(String, Allocated)], callingFunc: Boolean):
List[Instruction] = {
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
// Transform the _popList to a list only containing x registers
val popList = _popList.collect{case (id, reg: Reg) => (id, Register(reg.regNumber, X_REG))}
// First remove the memory that was stored away if we were calling a function
if (!popList.isEmpty && callingFunc) {
val spillPushList = _popList.collect { case (id, spill: Spill) => (id, spill)}
// Calculate the amount of stack space that was taken away
// (same rounding as in generatePushInstructions)
val _addedStackSpace = ((spillPushList.size + 1) & ~1) * EIGHT_BYTES
val addedStackSpace = Immediate(_addedStackSpace)
if (_addedStackSpace != 0) {
// Add that space back to the stack pointer
instructions += ADD(SP, SP, addedStackSpace)
}
}
if (popList.nonEmpty) {
instructions += Comment(s"pop {${_popList.reverse.map(_._2.toOperand).mkString(", ")}}")
// With an odd count the unpaired register was pushed last, so it sits at the head of the reversed list
val rest = if (popList.length % 2 == 1) {
instructions += LDP(popList.head._2, XZR, SP) // Pop first register if odd
popList.tail // Process the remaining elements
} else popList
rest.sliding(2, 2).foreach { pair =>
// As they were pushed in order, to maintain the order, pop in reverse taking right then left
instructions += LDP(pair(1)._2, pair(0)._2, SP)
}
}
// Release every popped register in the allocator
popList.foreach(pair => allocator.freeCallee(pair._2))
// We revert the identifier map to point back to the original registers if we called a function
// Note: we use _popList instead of popList as that contains the correct mapping for all allocated values
if (callingFunc) {
_popList.foreach { case (id, alloc) =>
allocator.addAllocMapping(id, alloc)
}
}
instructions.toList
}
/**
 * Reorders allocations two at a time: each pair is swapped, and a lone
 * trailing entry is preceded by a padding entry `("", XZR)`.
 * The result is the order in which the entries appear on the stack, reversed.
 */
def toStackList(list: List[(String, Allocated)]): List[(String, Allocated)] = {
  // Stand-in for the empty half of an unpaired slot
  val padding: (String, Allocated) = ("", XZR)
  list.grouped(2).toList.flatMap {
    case lower :: upper :: Nil => List(upper, lower)
    case group => List(padding, group(0))
  }
}
/**
 * Copies the contents of a spill location into a stack slot, using X17 as the
 * intermediate register.
 *
 * @param spill           location to read from
 * @param num             slot number; the slot lies `num * EIGHT_BYTES` below `totalStackSpace`
 * @param totalStackSpace size of the stack area the slot belongs to
 * @return the load/store pair performing the copy
 */
def pushMemToStack(spill: Spill, num: Int, totalStackSpace: BigInt): List[Instruction] = {
  // Offset from the stack pointer at which the value is saved
  val slotOffset = Immediate(totalStackSpace - num * EIGHT_BYTES)
  // Memory cannot be copied directly: load into x17, then store at the slot
  List(
    Load(X17, spill),
    Store(X17, SP, Some(slotOffset))
  )
}
/**
 * Generates the whole `main` routine: prologue, the code of every statement
 * of the main body, a default exit code of 0 in X0, and the epilogue.
 *
 * The statements are generated first because they populate the callee push
 * list that both the prologue and the epilogue are built from.
 *
 * @return the routine as a list of instructions
 */
def genMain(stmts: List[Stmt])
    (using gMainEnv: GlobalMainEnv, gFuncEnv: GlobalFuncEnv): List[Instruction] = {
  // Filled in by genStmt as variables are declared; main has no function parameters
  val calleeSaved: mutable.ListBuffer[(String, PhysicalReg)] = mutable.ListBuffer.empty
  val bodyCode = stmts.flatMap { stmt =>
    genStmt(stmt)(using calleeSaved, gMainEnv, gFuncEnv, List.empty)
  }

  // The prologue pushes the callee list in order...
  val prologue = funcPrologue("main", calleeSaved.toList)

  // ...X0 gets 0 as the default exit code, and the epilogue pops the list reversed
  prologue ++ bodyCode ++ (MoveOps(X0, Imm0) :: funcEpilogue(calleeSaved.reverse.toList))
}
/**
 * Generates all of the instructions for one function.
 *
 * Steps, in order:
 *  1. save the allocator state;
 *  2. give every parameter an allocation: a register named after its index while the
 *     index is below MAX_PARAM_REGS, otherwise a Spill addressed from FP;
 *  3. generate the body statements;
 *  4. prepend the prologue, built from the callee list collected so far;
 *  5. replace every `Comment("EPILOGUE", "")` placeholder with the real epilogue;
 *  6. restore the allocator state saved in step 1.
 *
 * Throws a RuntimeException when the function is not present in gFuncEnv.
 *
 * Returns a list of the instructions
 */
def genFunc(func: Func)(using gMainEnv: GlobalMainEnv, gFuncEnv: GlobalFuncEnv):
List[Instruction] = {
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
// The id of the function used to query the global function environment
val funcKey = func.id.id
// Replace # with an underscore for funcLabel
val funcLabel = func.id.id.replaceAll("#", "_")
// Every function has a different push list to see stored variables
given pushCalleeList: mutable.ListBuffer[(String, PhysicalReg)] = mutable.ListBuffer.empty
// Save the state of the allocator before setting up function
val prevState = allocator.saveState()
// Retrieve the parameters of the function
val funcParams = gFuncEnv.lookup(funcKey).getOrElse(
throw new RuntimeException(s"Function $funcKey not found in the environment")
).params
// Parameter allocations in declaration order; handed to genStmt as funcRegList
given funcRegList: mutable.ListBuffer[(String, Allocated)] = mutable.ListBuffer.empty
funcParams.zipWithIndex.foreach{case (param, index) =>
// Add the parameter to the global environment
gMainEnv.add(param)
// Add the parameter to the function register list
// Ints, Chars and Bools use the W_REG bit-mode, everything else X_REG
val bitMode = param._type match {
case (IntType | CharType | BoolType) => W_REG
case _ => X_REG
}
val paramAlloc = index match {
// If the index is below MAX_PARAM_REGS, then create a parameter register for it
case paramReg if paramReg < MAX_PARAM_REGS => Register(s"$paramReg", bitMode)
// Otherwise, create a space in memory for it
case nonReg =>
// Calculate the stack offset for the non register parameter
val offset = (nonReg - MAX_PARAM_REGS) * EIGHT_BYTES
// Adjust the offset by 16 Bytes to account for fp lr push onto stack
val adjustedOffset = BigInt(offset) + SIXTEEN_BYTES
val stackOffset = Immediate(adjustedOffset)
// Create a stack allocation with FP as the base pointer as this is how to access the
// variable in the function
Spill(stackOffset, FP)
}
// Adding the param's name with its arg register to the funcRegList
val listEntryFunc = (param.uniqueName, paramAlloc)
funcRegList += listEntryFunc
// Also add the allocation to the callee push list
// (only allocations that are a PhysicalReg are added; the rest are skipped)
val listEntryCallee = (param.uniqueName, paramAlloc)
listEntryCallee match {
case (identifier, physicalReg: PhysicalReg) =>
val entry = (identifier, physicalReg)
pushCalleeList += entry
case _ => ()
}
// Create the linkage between an id and its paramAlloc
allocator.addAllocMapping(param.uniqueName, paramAlloc)
}
// Generate code for each statement in the function block
func.body.foreach(stmt => instructions ++=
genStmt(stmt)(using pushCalleeList, gMainEnv, gFuncEnv, funcRegList.toList))
// Prepend the prologue with the pushCalleeList
funcPrologue(funcLabel, pushCalleeList.toList) ++=: instructions
// Replace all instances of a comment called "EPILOGUE" with the epilogue of the function
// At this point, we have all callee registers known to us
// NOTE: funcEpilogue is evaluated once per placeholder, i.e. once per return statement
val finalInstructions = instructions.flatMap {
case Comment("EPILOGUE", "") => funcEpilogue(pushCalleeList.reverse.toList)
case other => List(other)
}
// Restore the state to allow for callees to run as intended
allocator.restoreState(prevState)
finalInstructions.toList
}
/**
 * Generates code for a single statement, recursing into nested statements for
 * if, while and block.
 *
 * Statements that branch to a runtime routine (print, read, exit, free) first save the
 * allocations in funcRegList with generatePushInstructions and restore them afterwards
 * with generatePopInstructions, passing the list reversed.
 *
 * @param stmt           the statement to translate
 * @param pushCalleeList callee registers used so far; Declare appends to it
 * @param gMainEnv       environment used to look up expression/lhs types
 * @param gFuncEnv       function environment, used when typing a right-hand side
 * @param funcRegList    allocations of the enclosing function's parameters
 *                       (empty when called from genMain)
 * @return a list of instructions
 */
def genStmt(stmt: Stmt)
(using pushCalleeList: mutable.ListBuffer[(String, PhysicalReg)], gMainEnv: GlobalMainEnv,
gFuncEnv: GlobalFuncEnv, funcRegList: List[(String, Allocated)]): List[Instruction] = {
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
stmt match {
case Skip(_) =>
// No real instructions are needed; only a marker comment is emitted
instructions += Comment("SKIP")
instructions.toList
case Declare(_, t, ident, rhs) =>
instructions += Comment("DECLARE BEGINS HERE")
// Evaluate the right-hand side
val (rhsInstr, rhsAlloc) = genRhs(rhs, this)
instructions ++= rhsInstr
// Allocate a callee allocation for this ident
val calleeAlloc = rhsType(rhs)(using gMainEnv, gFuncEnv) match {
// Ints, Chars, and Bools all use W_REG register bit-modes
case (IntType | CharType | BoolType) => allocator.allocateCallee(W_REG)
case _ => allocator.allocateCallee(X_REG)
}
// Add the allocation to pushCalleeList (Will create STP instructions to store to the stack)
// Only allocations that are a PhysicalReg are recorded
val listEntryCallee = (ident.id, calleeAlloc)
listEntryCallee match {
case (identity, physicalReg: PhysicalReg) =>
val entry = (identity, physicalReg)
pushCalleeList += entry
case _ => ()
}
// Store the value of the rhs to the lhs
// Check on the whether the operand is on the stack or a register
instructions += Move(calleeAlloc, rhsAlloc)
// Add the allocation to the variable's map of vars to maps
allocator.addAllocMapping(ident.id, calleeAlloc)
// Free the temp register
allocator.freeTemp(rhsAlloc)
rhs match {
case Call(_, id, args) =>
// If the rhs is a function call, free all of the callee registers
instructions ++= generatePopInstructions(funcRegList.reverse, callingFunc = true)
case _ =>
}
instructions.toList
case Assign(_, lhs, rhs) =>
instructions += Comment("ASSIGN BEGINS HERE")
// Create w/x8 register using allocTemp as this is at the front of the temp pool (only for ArrayElems)
val bufferReg8 = lhsType(lhs)(using gMainEnv) match {
case (IntType | CharType | BoolType) => allocator.allocateTemp(W_REG)
case _ => allocator.allocateTemp(X_REG)
}
// Only free the lhs if it's not an arrayElem
lhs match {
case ArrayElem(_, _, _) => ()
case _ => allocator.freeTemp(bufferReg8)
}
val (rhsInstr, rhsAlloc) = genRhs(rhs, this)
instructions ++= rhsInstr
val newRhsAlloc = rhs match {
case Call(_, id, args) =>
// Set up the function scratch and return registers according to the lhs type
val (functionScratchReg, returnReg) = lhsType(lhs)(using gMainEnv) match {
case (IntType | CharType | BoolType) => (W16, W0)
case _ => (X16, X0)
}
// Move the output to a function scratch register
instructions += MOV(functionScratchReg, returnReg)
// Free all of the callee registers
instructions ++= generatePopInstructions(funcRegList.reverse, callingFunc = true)
functionScratchReg
case _ => rhsAlloc
}
// Get the lhs to move the rhs into
val (lhsInstr, lhsAlloc) = genLhsForUsing(lhs, instructions)
instructions ++= lhsInstr
// Make the lhsMove now
// Performs the moving of the rhs to the lhs which handles edge cases (ArrayElem, PairElem)
instructions ++= lhsMove(lhs, lhsAlloc, newRhsAlloc, bufferReg8)
// Free all temp allocations
allocator.freeTemp(newRhsAlloc)
allocator.freeTemp(rhsAlloc)
allocator.freeTemp(lhsAlloc)
allocator.freeTemp(bufferReg8)
instructions.toList
case pStmt: PrintBase =>
instructions += Comment("PRINT BEGINS HERE")
// Push registers onto the stack to save their value
instructions ++= generatePushInstructions(funcRegList, true)
val (exprInstr, exprAlloc) = genExpr(pStmt.expr)
// The bit mode of the return register (x/w 0) will be determined by the expression's type
val returnReg = exprType(pStmt.expr)(using gMainEnv) match {
case (IntType | CharType | BoolType) => W0
case _ => X0
}
instructions ++= exprInstr
instructions += MoveOps(returnReg, exprAlloc)
// Check for the type of the expression here
exprType(pStmt.expr)(using gMainEnv) match {
case BoolType =>
instructions += BL(_printb)
case CharType =>
instructions += BL(_printc)
case IntType =>
instructions += BL(_printi)
// Can print Char arrays as a string
case (StrType | ArrayType(CharType)) =>
instructions += BL(_prints)
// All of the other types require _printp
case _ =>
instructions += BL(_printp)
}
// Add the branch to _println for a Println statement
if (pStmt.isInstanceOf[Println]) {
instructions += BL(_println)
}
// Free allocation
allocator.freeTemp(exprAlloc)
// Pop off the saved registers from the stack
instructions ++= generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
instructions.toList
case Read(_, lhs) =>
instructions += Comment("READ BEGINS HERE")
// Save argument registers
instructions ++= generatePushInstructions(funcRegList, callingFunc = true)
// Create w8 register using allocTemp as this is at the front of the temp pool
val bufferReg8 = allocator.allocateTemp(W_REG)
val (lhsInstr, lhsAlloc) = genLhs(lhs)
// Ints and Chars will always use W_REG bit-mode registers
instructions ++= lhsInstr
lhs match {
// For a pair element: branch to _errNull when the allocation compares equal to 0, otherwise load through it
case _: PairElem =>
instructions ++= List(
Compare(lhsAlloc, Imm0),
BCond(EQ, _errNull),
Load(X0, lhsAlloc)
)
case _ => instructions += Move(W0, lhsAlloc)
}
// Check for the type of the lhs here
lhsType(lhs)(using gMainEnv) match {
case CharType =>
instructions += BL(_readc)
case IntType =>
instructions += BL(_readi)
// Any other type emits only a marker comment and no branch
case _ =>
instructions += Comment("UNKNOWN TYPE IN READ")
}
// Finish register manipulation
// Depending on the type of the lhs, invoke the correct next move
// Ints and Chars will always use W_REG bit-mode registers
instructions += MoveOps(W16, W0)
// Free the lhsAlloc before reloading lhs
allocator.freeTemp(lhsAlloc)
// Reload the lhs so that we get the correct address
val (reloadInstrs, reloadLhs) = genLhsForUsing(lhs, instructions)
instructions ++= reloadInstrs
instructions ++= lhsMove(lhs, reloadLhs, W16, bufferReg8)
allocator.freeTemp(bufferReg8)
allocator.freeTemp(reloadLhs)
// Pop off argument registers
instructions ++= generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
instructions.toList
case Exit(_, expr) =>
instructions += Comment("EXIT BEGINS HERE")
// Push registers onto the stack to save their value
instructions ++= generatePushInstructions(funcRegList, true)
// For an exit statement, evaluate the expression and call the exit routine
val (exprInstr, exprAlloc) = genExpr(expr)
instructions ++= exprInstr
// The expression is always an int so we use the W_REG bit-mode
instructions += MoveOps(W0, exprAlloc)
instructions += BL(exit)
// Pop off argument registers
instructions ++= generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
allocator.freeTemp(exprAlloc)
instructions.toList
case Return(_, expr) =>
instructions += Comment("RETURN BEGINS HERE")
val (exprInstr, exprAlloc) = genExpr(expr)
// The bit mode of the return register (x/w 0) will be determined by the expression's type
val returnReg = exprType(expr)(using gMainEnv) match {
case (IntType | CharType | BoolType) => W0
case _ => X0
}
instructions ++= exprInstr
instructions += MoveOps(returnReg, exprAlloc)
// We mark this section as needing to have an epilogue
// It is filled in later in the genFunc method
instructions += Comment("EPILOGUE", "")
allocator.freeTemp(exprAlloc)
instructions.toList
case If(_, cond, thenStats, elseStats) =>
instructions += Comment("IF BEGINS HERE")
// Generate fresh labels for then and end case. Else case would go right after label
val thenLabel = labelGenerator.freshLabel()
val endLabel = labelGenerator.freshLabel()
// Generate code for the condition, ending with a Compare instruction
val (condInstr, branchFlag) = genCond(cond)
instructions ++= condInstr
// Branch to the then block if the condition is true
instructions += BCond(branchFlag, thenLabel)
// Continue onto generating code for the else block
elseStats.foreach(stmt => instructions ++= genStmt(stmt))
// Branch to the end label after finishing else block
instructions += B(endLabel)
// Label and instructions for then
instructions += Label(thenLabel)
thenStats.foreach(stmt => instructions ++= genStmt(stmt))
// End of the if statement
instructions += Label(endLabel)
instructions.toList
case While(_, cond, body) =>
instructions += Comment("WHILE BEGINS HERE")
// Generate fresh labels for the loop
val condLabel = labelGenerator.freshLabel()
val bodyLabel = labelGenerator.freshLabel()
// Branch to the condition check, which is emitted after the body
instructions += B(condLabel)
// Generate code for the loop body
instructions += Label(bodyLabel)
body.foreach(stmt => instructions ++= genStmt(stmt))
// After the body, check the condition again
instructions += B(condLabel)
// Generate code for the condition check
instructions += Label(condLabel)
val (condInstr, condFlag) = genCond(cond)
instructions ++= condInstr
// Branch back to the body while the condition holds
instructions += BCond(condFlag, bodyLabel)
// End of the loop just continues
instructions.toList
// Generates each statement in the block
case Block(_, stmts) => stmts.flatMap(genStmt)
case Free(_, expr: Rhs) =>
// The expression is evaluated before the push below, but its instructions are emitted after it
val (exprInstr, exprAlloc) = genRhs(expr, this)
// Save argument registers
instructions ++= generatePushInstructions(funcRegList, callingFunc = true)
instructions ++= exprInstr
exprType(expr)(using gMainEnv) match {
case _: ArrayType =>
instructions ++= List(
Comment("array pointers are shifted forward by 4 bytes, " +
"so correct it back to original pointer before free"),
SUB(exprAlloc, exprAlloc, Imm4),
MoveOps(X0, exprAlloc),
BL(free)
)
// The `Any` alternative makes this a catch-all: every non-array type is freed with _freepair
case _: (PairType | PairPlaceholder.type | Any) =>
instructions += MoveOps(X0, exprAlloc)
instructions += BL(_freepair)
}
// Pop off argument registers
instructions ++= generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
allocator.freeTemp(exprAlloc)
instructions.toList
}
}
}

View File

@ -0,0 +1,449 @@
package wacc.backend
import wacc.frontend.syntax.ast._
import wacc.frontend.semantic.environment._
import scala.collection.mutable
/**
 * Generates code for an expression.
 *
 * A temporary register is allocated before the expression is inspected: a W register
 * when `exprType` reports int, char or bool, and an X register otherwise. Leaf cases
 * write their result into that temp. Cases that evaluate sub-expressions either free
 * it, call `allocator.restore()`, or delegate to a helper that does so.
 * NOTE(review): `restore()` presumably hands back the temp allocated here so that the
 * first sub-expression can reuse it - confirm against RegisterAllocator.
 *
 * @param expr the expression to generate code for
 * @return a tuple: (instructions, register holding the result)
 */
def genExpr(expr: Expr)
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
(List[Instruction], PhysicalReg) = {
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
// The bit-mode of the temp register depends on the type of the expression
val tempAlloc = exprType(expr) match {
case (IntType | CharType | BoolType) => allocator.allocateTemp(W_REG)
case _ => allocator.allocateTemp(X_REG)
}
expr match {
// ---- Literals: materialise the value directly in the temp register ----
case IntLiter(_, value) =>
val valueImm = Immediate(value)
instructions += MoveOps(tempAlloc, valueImm)
(instructions.toList, tempAlloc)
case BoolLiter(_, _value) =>
// Booleans are encoded as 1 (true) / 0 (false)
val value = if (_value) 1 else 0
val valueImm = Immediate(value)
instructions += MoveOps(tempAlloc, valueImm)
(instructions.toList, tempAlloc)
case CharLiter(_, value) =>
// Characters are moved as their integer code
val valueImm = Immediate(value.toInt)
instructions += MoveOps(tempAlloc, valueImm)
(instructions.toList, tempAlloc)
case StrLiter(_, value) =>
// The label generator hands out (and de-duplicates) the label for this literal;
// the address is formed with an ADRP followed by an ADD of the label operand
val strLabel = labelGenerator.freshLabel(value)
instructions += ADRP(tempAlloc, strLabel)
instructions += ADD(tempAlloc, tempAlloc, LabelOp(strLabel))
(instructions.toList, tempAlloc)
case PairLiter(_) =>
// The pair literal is represented by Imm0
instructions += MoveOps(tempAlloc, Imm0)
(instructions.toList, tempAlloc)
case Ident(_, id) =>
val _idAlloc = allocator.getAllocMapping(id)
// Make the bit-mode of the idAlloc the same as the bitmode of the tempAlloc
// (spilled variables are passed through unchanged; Move handles both shapes)
val idAlloc = _idAlloc match {
case reg: PhysicalReg => tempAlloc match {
case PhysicalReg(_, X_REG, _) => PhysicalReg(reg.regNumber, X_REG, reg.pool)
case PhysicalReg(_, other, _) => PhysicalReg(reg.regNumber, W_REG, reg.pool)
}
case spill => spill
}
instructions += Move(tempAlloc, idAlloc)
(instructions.toList, tempAlloc)
// ---- Arithmetic ----
// Add/Sub give the temp back and let genArithOp allocate for the operands;
// the flag-setting ADDS/SUBS forms are used so overflow can be detected
case Add(lhs, rhs) =>
allocator.freeTemp(tempAlloc)
genArithOp(lhs, rhs, ADDS.apply)
case Sub(lhs, rhs) =>
allocator.freeTemp(tempAlloc)
genArithOp(lhs, rhs, SUBS.apply)
case Mul(lhs, rhs) =>
// Restore the register to make efficient use of registers
allocator.restore()
val (lhsInstr, lhsAlloc) = genExpr(lhs)
val (rhsInstr, rhsAlloc) = genExpr(rhs)
// Need an x bit-mode register for detecting overflow errors
val xLhs = PhysicalReg(lhsAlloc.regNumber, X_REG, lhsAlloc.pool)
instructions ++= lhsInstr
instructions ++= rhsInstr
// The product is written to the X view of lhs, then compared with the
// sign-extended W view; a mismatch branches to the overflow handler
instructions ++= List(
SMULL(xLhs, lhsAlloc, rhsAlloc),
Compare(xLhs, lhsAlloc ,Some(SXTW())),
BCond(NE, _errOverflow)
)
allocator.freeTemp(rhsAlloc)
(instructions.toList, lhsAlloc)
// Handles Division and Modulo operations
case divOp: DivOp =>
val (lhs, rhs) = (divOp.lhs, divOp.rhs)
allocator.restore()
val (lhsInstr, lhsAlloc) = genExpr(lhs)
val (rhsInstr, rhsAlloc) = genExpr(rhs)
instructions ++= lhsInstr
instructions ++= rhsInstr
// Branch to the divide-by-zero handler when the divisor equals Imm0
instructions ++= List(
Compare(rhsAlloc, Imm0),
BCond(EQ, _errDivZero),
)
divOp match {
case _: Div => instructions += SDIV(lhsAlloc, lhsAlloc, rhsAlloc)
case _: Mod =>
// Quotient goes to the scratch register W17, then MSUB forms
// lhs - quotient * rhs, i.e. the remainder, in the lhs register
instructions += SDIV(W17, lhsAlloc, rhsAlloc)
instructions += MSUB(lhsAlloc, W17, rhsAlloc, lhsAlloc)
}
allocator.freeTemp(rhsAlloc)
(instructions.toList, lhsAlloc)
// ---- Boolean and comparison operators are delegated to helpers ----
case And(lhs, rhs) => genBoolOp(lhs, rhs, NE)
case Or(lhs, rhs) => genBoolOp(lhs, rhs, EQ)
case compOp: CompOp => genComparison(compOp.lhs, compOp.rhs, compOp)
// ---- Unary operators ----
case chr: Chr =>
val (chrInstr, chrAlloc) = genUnaryOp(chr.expr, chr)
// X view of the result register, used as a CSEL operand below
val xChr = chrAlloc match {
case PhysicalReg(regNumber, _ , pool) => PhysicalReg(regNumber, X_REG, pool)
}
instructions ++= chrInstr
// TST against -128 checks whether any bit above the low seven is set;
// if so the value is copied to X1 and control goes to the bad-char handler
val testVal = -128
val tstImm = Immediate(testVal)
instructions ++= List(
TST(chrAlloc, tstImm),
CSEL(X1, xChr, X1, NE),
BCond(NE, _errBadChar)
)
(instructions.toList, chrAlloc)
case neg: Neg =>
// genUnaryOp emits the negation; the VS branch afterwards traps overflow
val (negInstr, negAlloc) = genUnaryOp(neg.expr, neg)
instructions ++= negInstr
instructions += BCond(VS, _errOverflow)
(instructions.toList, negAlloc)
case unOp: UnOp => genUnaryOp(unOp.expr, unOp)
// ---- Conditional expression ----
// Layout: <cond>; b.<flag> then; <else>; b end; then: <then>; b end; end:
// NOTE(review): neither `tempAlloc` (allocated at the top of this function) nor the
// branch result registers `elseAlloc`/`thenAlloc` are freed in this case - confirm
// that this is intentional or reclaimed elsewhere.
case IfExpr(pos, cond, thenBranch, elseBranch) =>
val thenLabel = labelGenerator.freshLabel()
val endLabel = labelGenerator.freshLabel()
val (condInstr, branchFlag) = genCond(cond)
instructions ++= condInstr
// Branch to the then block if the condition is true
instructions += BCond(branchFlag, thenLabel)
// The shared result register is sized from the type of the then-branch
val resultReg = exprType(thenBranch) match {
case (IntType | CharType | BoolType) => allocator.allocateTemp(W_REG)
case _ => allocator.allocateTemp(X_REG)
}
val (elseInstr, elseAlloc) = genExpr(elseBranch)
instructions ++= elseInstr
instructions += Move(resultReg, elseAlloc)
instructions += B(endLabel)
instructions += Label(thenLabel)
val (thenInstr, thenAlloc) = genExpr(thenBranch)
instructions ++= thenInstr
instructions += Move(resultReg, thenAlloc)
instructions += B(endLabel)
instructions += Label(endLabel)
(instructions.toList, resultReg)
// ---- Pair element read ----
case PairElem(_, selector, lhs: Expr) =>
allocator.restore()
val (exprInstr, exprAlloc) = genExpr(lhs)
val xExpr = PhysicalReg(exprAlloc.regNumber, X_REG, exprAlloc.pool)
instructions ++= exprInstr
// Null check: a pair pointer equal to Imm0 branches to the null handler
instructions += Compare(exprAlloc, Imm0)
instructions += BCond(EQ, _errNull)
// fst is loaded from the pair pointer itself, snd from offset Imm8
selector match {
case Fst => instructions += LDR(xExpr, exprAlloc)
case Snd => instructions += LDR(xExpr, exprAlloc, Some(Imm8))
}
// We use knowledge of the bitmode from tempAlloc to get the right bitmode for return
val resultReg = PhysicalReg(xExpr.regNumber, tempAlloc.bitMode, xExpr.pool)
(instructions.toList, resultReg)
// ---- Array element read (one _arrLoad call per index) ----
case ArrayElem(_, id, indices) =>
// Get the array type of the identifier we are accessing
var curType = exprType(id) match {
case ArrayType(innerType) => innerType
case _type => _type
}
val xTemp = PhysicalReg(tempAlloc.regNumber, X_REG, tempAlloc.pool)
for ((indexExpr, i) <- indices.zipWithIndex) {
instructions += Comment(s"curType = $curType, indexExpr = $indexExpr, i = $i")
// For every index after the first, xTemp holds the array pointer produced by the
// previous iteration; push it so evaluating the next index cannot clobber it
if (i != 0) {
instructions += Comment(s"push {${xTemp}}")
instructions += STP(xTemp, XZR, SP)
}
// Generate code for the index expression
val (indexInstr, indexAlloc) = genExpr(indexExpr)
instructions ++= indexInstr
// The index is passed to the array-load helper in W17
instructions += MoveOps(W17, indexAlloc)
allocator.freeTemp(indexAlloc)
// If we're not on the first iteration, pop the saved pointer into x7;
// on the first iteration x7 is loaded from the array identifier itself
if (i != 0) {
instructions += Comment("pop {x7}")
instructions += LDP(X7, XZR, SP)
} else {
val (idInstr, idAlloc) = genExpr(id)
instructions ++= idInstr
instructions += MoveOps(X7, idAlloc)
allocator.freeTemp(idAlloc)
}
// Call the appropriate array-load function (selected by element size)
instructions += BL(_arrLoad(s"${sizeof(curType)}"))
// The loaded element is read back from register 7: as a pointer (X7) when the
// element is itself an array, otherwise in the bit-mode of tempAlloc
curType match {
case ArrayType(_) => instructions += MoveOps(xTemp, X7)
case _ =>
val reg7 = tempAlloc match {
case reg: Reg => reg.bitMode match {
case W_REG => W7
case X_REG => X7
// This case shouldn't happen (Maybe call it a random register)
case _ => Register("Should know type of reg7", ' ')
}
}
instructions += MoveOps(tempAlloc, reg7)
}
// Step one level down into the element type for the next index
curType match {
case ArrayType(innerType) => curType = innerType
case _ => ()
}
}
(instructions.toList, tempAlloc)
}
}
/**
 * Emits code for a two-operand operation.
 *
 * The operand with the strictly greater `exprWeight` is evaluated first; on a tie
 * the right operand is evaluated first. The result overwrites the register of the
 * left operand and the register of the right operand is freed.
 *
 * @param lhs left operand
 * @param rhs right operand
 * @param op  builds the instruction from (destination, first source, second source)
 * @return the emitted instructions and the register holding the result
 */
def genBinaryOp(lhs: Expr, rhs: Expr, op: (Reg, Operand, Operand) => Instruction)(
    using gMainEnv: GlobalMainEnv,
    allocator: RegisterAllocator,
    labelGenerator: LabelGenerator
): (List[Instruction], PhysicalReg) = {
  val code = mutable.ListBuffer.empty[Instruction]
  val leftIsHeavier = exprWeight(lhs) > exprWeight(rhs)
  // Whichever side goes first, the pair comes back as (left register, right register)
  val (leftReg, rightReg) =
    if (leftIsHeavier) getEvaluatedExpr(lhs, rhs, code, true)
    else getEvaluatedExpr(rhs, lhs, code, false)
  code += op(leftReg, leftReg, rightReg)
  allocator.freeTemp(rightReg)
  (code.toList, leftReg)
}
/**
 * Evaluates two expressions in the order given (`lhs` first, then `rhs`) and appends
 * their code to `instructions`.
 *
 * @param lhs          expression evaluated first
 * @param rhs          expression evaluated second
 * @param instructions buffer that receives the generated code
 * @param takeLeft     when true the registers are returned as (first, second),
 *                     otherwise swapped to (second, first)
 * @return the two result registers, ordered according to `takeLeft`
 */
def getEvaluatedExpr(lhs: Expr, rhs: Expr, instructions: mutable.ListBuffer[Instruction], takeLeft: Boolean)(
    using gMainEnv: GlobalMainEnv,
    allocator: RegisterAllocator,
    labelGenerator: LabelGenerator
): (PhysicalReg, PhysicalReg) = {
  val (firstCode, firstReg) = genExpr(lhs)
  val (secondCode, secondReg) = genExpr(rhs)
  instructions ++= firstCode
  instructions ++= secondCode
  if (takeLeft) (firstReg, secondReg) else (secondReg, firstReg)
}
/**
 * Nesting depth of binary operators in an expression: 0 for anything that is not a
 * `BinOp`, otherwise one more than the heavier of its two operands.
 */
def exprWeight(expr: Expr): Int = expr match {
  case node: BinOp =>
    val leftWeight = exprWeight(node.lhs)
    val rightWeight = exprWeight(node.rhs)
    1 + (leftWeight max rightWeight)
  case _ => 0
}
/**
 * Emits a binary arithmetic operation followed by an overflow check: the code from
 * `genBinaryOp`, then a branch to `_errOverflow` taken on the VS condition.
 *
 * @param op builds the arithmetic instruction; callers pass a flag-setting form
 * @return the emitted instructions and the register holding the result
 */
def genArithOp(lhs: Expr, rhs: Expr, op: (Reg, Operand, Operand) => Instruction)(
    using gMainEnv: GlobalMainEnv,
    allocator: RegisterAllocator,
    labelGenerator: LabelGenerator
): (List[Instruction], PhysicalReg) = {
  val (arithCode, resultReg) = genBinaryOp(lhs, rhs, op)
  val overflowCheck = BCond(VS, _errOverflow)
  (arithCode :+ overflowCheck, resultReg)
}
/**
 * Generates assembly for a unary operator.
 *
 * The operand is evaluated, its register is freed, and a W register is allocated
 * for the result before the operator's single instruction is appended.
 *
 * @param expr the operand
 * @param unOp the operator node; selects the instruction that is emitted
 * @return the emitted instructions and the W register holding the result
 */
def genUnaryOp(expr: Expr, unOp: UnOp)(
    using gMainEnv: GlobalMainEnv,
    allocator: RegisterAllocator,
    labelGenerator: LabelGenerator
): (List[Instruction], PhysicalReg) = {
  // Pick the (destination, source) instruction builder for this operator
  val buildInstr = unOp match {
    case _: Not => EOR.apply(_, _, Imm1)
    case _: Len => LDUR.apply
    case _: Neg => NEGS.apply
    case _: (Ord | Chr) => MoveOps.apply
  }
  // Restores a temp register as they will collapse onto each other
  allocator.restore()
  val (operandCode, operandReg) = genExpr(expr)
  // The operand register is released before the result register is allocated
  allocator.freeTemp(operandReg)
  val resultReg = allocator.allocateTemp(W_REG)
  (operandCode :+ buildInstr(resultReg, operandReg), resultReg)
}
/**
 * Generates assembly for a comparison expression.
 *
 * Both operands are evaluated, compared, and the outcome is materialised with CSET
 * in the W view of the left operand's register. The right operand's register is freed.
 *
 * @param compOp the comparison node; selects the condition flag
 * @return the emitted instructions and the W register holding the boolean result
 */
def genComparison(lhs: Expr, rhs: Expr, compOp: CompOp)(
    using gMainEnv: GlobalMainEnv,
    allocator: RegisterAllocator,
    labelGenerator: LabelGenerator
): (List[Instruction], PhysicalReg) = {
  val flag = compOp match {
    case Gt(_, _) => GT
    case Geq(_, _) => GE
    case Lt(_, _) => LT
    case Leq(_, _) => LE
    case Eq(_, _) => EQ
    case Neq(_, _) => NE
  }
  // Restore a temp register to keep register use efficient
  allocator.restore()
  val (leftCode, leftReg) = genExpr(lhs)
  val (rightCode, rightReg) = genExpr(rhs)
  // The boolean result lives in the W view of the left register
  val boolReg = PhysicalReg(leftReg.regNumber, W_REG, leftReg.pool)
  val code: List[Instruction] =
    leftCode ++ rightCode ++ List(Compare(leftReg, rightReg), CSET(boolReg, flag))
  allocator.freeTemp(rightReg)
  (code, boolReg)
}
/**
 * Generates assembly for a short-circuiting boolean operator.
 *
 * The left operand is compared with 1 and, when `cond` holds, control jumps straight
 * to the final CSET, skipping the right operand. Otherwise the right operand is
 * evaluated and compared with 1. In both paths the CSET turns the EQ flag of the
 * last comparison into the boolean result.
 *
 * @param cond flag on which the right operand is skipped
 * @return the emitted instructions and the register holding the boolean result
 */
def genBoolOp(lhs: Expr, rhs: Expr, cond: Flag)(
    using gMainEnv: GlobalMainEnv,
    allocator: RegisterAllocator,
    labelGenerator: LabelGenerator
): (List[Instruction], PhysicalReg) = {
  allocator.restore()
  val (leftCode, leftReg) = genExpr(lhs)
  val skipLabel = labelGenerator.freshLabel()
  val leftCheck: List[Instruction] = List(
    Compare(leftReg, Imm1),
    BCond(cond, skipLabel)
  )
  // The left register is released before the right operand is evaluated
  allocator.freeTemp(leftReg)
  val (rightCode, rightReg) = genExpr(rhs)
  val rightCheck: List[Instruction] = List(
    Compare(rightReg, Imm1),
    Label(skipLabel),
    CSET(rightReg, EQ)
  )
  (leftCode ++ leftCheck ++ rightCode ++ rightCheck, rightReg)
}
/**
 * Generates code for a condition.
 *
 * The function guarantees that a Compare instruction is generated, whose flags the
 * caller then branches on.
 *
 *  - Comparison operators compare their two operands directly and return the
 *    matching flag.
 *  - `!e` compares `e` with `true` and returns NE.
 *  - Every other expression (identifier, boolean literal, `&&`/`||`, array or pair
 *    element, if-expression, ...) is evaluated with `genExpr`, compared with 1 and
 *    returns EQ. Handling this generically avoids ever returning a placeholder flag
 *    that cannot be assembled.
 *
 * @param cond the boolean expression to test
 * @return a tuple: (instructions, flag on which the condition is true)
 */
def genCond(cond: Expr)(
    using gMainEnv: GlobalMainEnv,
    allocator: RegisterAllocator,
    labelGenerator: LabelGenerator
): (List[Instruction], Flag) = {
  cond match {
    case Neq(lhs, rhs) => (generateOpCond(lhs, rhs), NE)
    case Eq(lhs, rhs) => (generateOpCond(lhs, rhs), EQ)
    case Not(expr) => (generateOpCond(expr, BoolLiter(IGNORE, true)), NE)
    case Gt(lhs, rhs) => (generateOpCond(lhs, rhs), GT)
    case Lt(lhs, rhs) => (generateOpCond(lhs, rhs), LT)
    case Geq(lhs, rhs) => (generateOpCond(lhs, rhs), GE)
    case Leq(lhs, rhs) => (generateOpCond(lhs, rhs), LE)
    case _ =>
      // Any other boolean-valued expression: compute it, then test it against true
      val (exprInstr, exprAlloc) = genExpr(cond)
      val instructions: List[Instruction] = exprInstr :+ Compare(exprAlloc, Imm1)
      allocator.freeTemp(exprAlloc)
      (instructions, EQ)
  }
}
/**
 * Generates the instructions that compare two expressions: the code for `lhs`, the
 * code for `rhs`, then a Compare of the two result registers. Both registers are
 * freed before returning, so only the flags carry the outcome.
 */
def generateOpCond(lhs: Expr, rhs: Expr)(
    using gMainEnv: GlobalMainEnv,
    allocator: RegisterAllocator,
    labelGenerator: LabelGenerator
): List[Instruction] = {
  val (leftCode, leftReg) = genExpr(lhs)
  val (rightCode, rightReg) = genExpr(rhs)
  val comparison = Compare(leftReg, rightReg)
  allocator.freeTemp(leftReg)
  allocator.freeTemp(rightReg)
  (leftCode ++ rightCode) :+ comparison
}

View File

@ -0,0 +1,471 @@
package wacc.backend
import scala.collection.mutable
// Bounds of a small signed immediate.
// NOTE(review): presumably the range of an immediate load/store offset - the Load and
// Store wrappers in this file test offsets against the same -256 / 255 limits.
val MIN_IMM_VAL: Int = -256
val MAX_IMM_VAL: Int = 255
/**
 * A representation for AArch64 assembly instructions.
 * Each instruction is wrapped in a case class that knows how to emit its string.
 */
sealed trait Instruction {
  /** Renders this instruction as assembly text. */
  def emit: String
}
// --------------------------
// Basic instruction classes
// --------------------------
/** A `//` comment line; `tabs` is the text placed in front of it. */
case class Comment(name: String, tabs: String = " ") extends Instruction {
  override def emit: String = tabs + "// " + name
}
/** An assembler directive: `name` prefixed with a dot, after the `tabs` prefix. */
case class Directive(name: String, tabs: String = "") extends Instruction {
  override def emit: String = tabs + "." + name
}
/** A label definition: `name` followed by a colon. */
case class Label(name: String) extends Instruction {
  override def emit: String = name + ":"
}
// --------------------------
// Datamovement instructions
// --------------------------
// MOV: copies a source operand into a destination register.
case class MOV(dest: Reg, src: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, src.toOperand).mkString(" mov ", ", ", "")
}
/**
 * Pseudo-instruction that moves an operand into a register, expanding to however
 * many real instructions the operand needs.
 *
 *  - register source: a single MOV
 *  - immediate in 0..0xFFFF (inclusive): a single MOV
 *  - any other immediate: MOV of the low 16 bits followed by MOVK of bits 16-31
 *    shifted left by 16
 *  - label operand: nothing is emitted
 *
 * NOTE(review): the two-instruction form only populates the low 32 bits, so a negative
 * immediate moved into an X register would not be sign-extended - confirm callers only
 * pass negative values for W destinations.
 */
case class MoveOps(dest: Reg, src: Operand) extends Instruction {
  override def emit: String = getInstrs().map(_.emit).mkString("\n")

  /** The concrete instruction sequence this move expands to. */
  def getInstrs(): List[Instruction] = {
    val bitMask = 0xFFFF
    val shift16 = 16
    src match {
      case _: Reg => List(MOV(dest, src))
      // Label Op shouldn't be in the move (at least not for now?)
      case _: LabelOp => Nil
      case Immediate(value) =>
        val valInt = value.toInt
        if (valInt >= 0 && valInt <= bitMask) {
          // Fits in one 16-bit move immediate (0xFFFF itself included)
          List(MOV(dest, src))
        } else {
          // Handles immediates outside that range by breaking them into two halves
          val immOne = Immediate(valInt & bitMask)
          val immTwo = Immediate((valInt >> shift16) & bitMask)
          List(MOV(dest, immOne), MOVK(dest, immTwo, LSL(Imm16)))
        }
    }
  }
}
/** MOVK: emitted as `movk dest, src, shift`. */
case class MOVK(dest: Reg, src: Operand, shift: Shift) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, src.toOperand, shift.emit).mkString(" movk ", ", ", "")
}
// --------------------------
// Arithmetic instructions
// --------------------------
/** ADD: `add dest, op1, op2`. */
case class ADD(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, op1.toOperand, op2.toOperand).mkString(" add ", ", ", "")
}
/** ADDS: `adds dest, op1, op2`. */
case class ADDS(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, op1.toOperand, op2.toOperand).mkString(" adds ", ", ", "")
}
/** SUB: `sub dest, op1, op2`. */
case class SUB(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, op1.toOperand, op2.toOperand).mkString(" sub ", ", ", "")
}
/** SUBS: `subs dest, op1, op2`. */
case class SUBS(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, op1.toOperand, op2.toOperand).mkString(" subs ", ", ", "")
}
/** SMULL: `smull dest, op1, op2`. */
case class SMULL(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, op1.toOperand, op2.toOperand).mkString(" smull ", ", ", "")
}
/** SDIV: `sdiv dest, op1, op2`. */
case class SDIV(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, op1.toOperand, op2.toOperand).mkString(" sdiv ", ", ", "")
}
/** MSUB: `msub dest, op1, op2, op3`. */
case class MSUB(dest: Reg, op1: Operand, op2: Operand, op3: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, op1.toOperand, op2.toOperand, op3.toOperand).mkString(" msub ", ", ", "")
}
/** NEGS: `negs dest, src`. */
case class NEGS(dest: Reg, src: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, src.toOperand).mkString(" negs ", ", ", "")
}
// --------------------------
// Bitwise instructions
// --------------------------
/** EOR: `eor dest, op1, op2`. */
case class EOR(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, op1.toOperand, op2.toOperand).mkString(" eor ", ", ", "")
}
// --------------------------
// Address instructions
// --------------------------
/** ADRP: `adrp dest, label`. */
case class ADRP(dest: Reg, label: String) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, label).mkString(" adrp ", ", ", "")
}
/** ADR: `adr dest, label`. */
case class ADR(dest: Reg, label: String) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, label).mkString(" adr ", ", ", "")
}
// --------------------------
// Comparison and flag-setting
// --------------------------
/** The `sxtw` extend modifier; used as the optional third operand of CMP. */
case class SXTW() extends Instruction {
  override def emit: String = "sxtw"
}
/** CMP: `cmp op1, op2`, with `, sxtw` appended when an extend modifier is supplied. */
case class CMP(op1: Operand, op2: Operand, _sxtw: Option[SXTW] = None) extends Instruction {
  override def emit: String = {
    val plain = Seq(op1.toOperand, op2.toOperand).mkString(" cmp ", ", ", "")
    _sxtw.fold(plain)(ext => plain + ", " + ext.emit)
  }
}
/** CSET: `cset dest, condition`. */
case class CSET(dest: Reg, condition: Flag) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, condition.emit).mkString(" cset ", ", ", "")
}
/** TST: `tst reg, op`. */
case class TST(reg: Reg, op: Operand) extends Instruction {
  override def emit: String =
    Seq(reg.toOperand, op.toOperand).mkString(" tst ", ", ", "")
}
// --------------------------
// Branch instructions
// --------------------------
/** B: unconditional branch to `label`. */
case class B(label: String) extends Instruction {
  override def emit: String = " b " + label
}
/**
 * BL: branch-with-link to `label`.
 * `_ignore` is not used by `emit`; it distinguishes the generated constructor from
 * the companion's `apply(label)`, which also registers the label.
 */
case class BL(label: String, _ignore: String = "") extends Instruction {
  override def emit: String = " bl " + label
}
object BL {
  /** Builds a BL and reports its target to the label generator via `addWidget`. */
  def apply(label: String)(using labelGenerator: LabelGenerator): BL = {
    val call = new BL(label)
    labelGenerator.addWidget(label)
    call
  }
}
/**
 * BCond: conditional branch `b.<flag> label`.
 * `_ignore` plays the same role as in BL.
 */
case class BCond(flag: Flag, label: String, _ignore: String = "") extends Instruction {
  override def emit: String = " b." + flag.emit + " " + label
}
object BCond {
  /** Builds a BCond and reports its target to the label generator via `addWidget`. */
  def apply(flag: Flag, label: String)(using labelGenerator: LabelGenerator): BCond = {
    val branch = new BCond(flag, label)
    labelGenerator.addWidget(label)
    branch
  }
}
/** CBZ: `cbz reg, label`. */
case class CBZ(reg: Reg, label: String) extends Instruction {
  override def emit: String =
    Seq(reg.toOperand, label).mkString(" cbz ", ", ", "")
}
/** CSEL: `csel dest, src1, src2, condition`. */
case class CSEL(dest: Reg, src1: Reg, src2: Reg, condition: Flag) extends Instruction {
  override def emit: String =
    Seq(dest.toOperand, src1.toOperand, src2.toOperand, condition.emit).mkString(" csel ", ", ", "")
}
// --------------------------
// Load/store instructions
// --------------------------
// A shift modifier that can be attached to an operand
sealed trait Shift {
  /** Renders the shift as assembly text. */
  def emit: String
}
/** Logical shift left by an immediate amount: `lsl <amount>`. */
case class LSL(amount: Immediate) extends Shift {
  override def emit: String = "lsl " + amount.toOperand
}
// LDR: `ldr dest, [base]`, `ldr dest, [base, offset]` or `ldr dest, [base, offset, shift]`.
// A shift supplied without an offset is ignored and the plain `[base]` form is emitted.
case class LDR(dest: Reg, base: Reg, offset: Option[Operand] = None, shift: Option[Shift] = None) extends Instruction {
  override def emit: String = {
    val address = offset match {
      case Some(off) =>
        val shiftSuffix = shift.fold("")(sft => s", ${sft.emit}")
        s"${base.toOperand}, ${off.toOperand}$shiftSuffix"
      case None => s"${base.toOperand}"
    }
    s" ldr ${dest.toOperand}, [$address]"
  }
}
// LDUR: always loads from the fixed offset #-4 relative to `base`
case class LDUR(dest: Reg, base: Operand) extends Instruction {
  override def emit: String =
    " ldur " + dest.toOperand + ", [" + base.toOperand + ", #-4]"
}
// LDP: loads a register pair from `[base]`, then post-increments `base` by 16
case class LDP(dest1: Reg, dest2: Reg, base: Reg) extends Instruction {
  override def emit: String =
    " ldp " + dest1.toOperand + ", " + dest2.toOperand + ", [" + base.toOperand + "], #16"
}
// LDRB: `ldrb dest, [base, offset]`
case class LDRB(dest: Reg, base: Reg, offset: Operand) extends Instruction {
  override def emit: String =
    " ldrb " + dest.toOperand + ", [" + base.toOperand + ", " + offset.toOperand + "]"
}
// STP: stores a register pair with pre-indexed writeback (`[base, <offset>]!`).
// When no offset is given, #-16 is used.
case class STP(src1: Reg, src2: Reg, base: Reg, offset: Option[Operand] = None) extends Instruction {
  override def emit: String = {
    val displacement = offset.map(off => s"${off.toOperand}").getOrElse("#-16")
    s" stp ${src1.toOperand}, ${src2.toOperand}, [${base.toOperand}, $displacement]!"
  }
}
// STR: `str src, [base]`, `str src, [base, offset]` or `str src, [base, offset, shift]`.
// A shift supplied without an offset is ignored and the plain `[base]` form is emitted.
case class STR(src: Reg, base: Reg, offset: Option[Operand] = None, shift: Option[Shift] = None) extends Instruction {
  override def emit: String = {
    val address = offset match {
      case Some(off) =>
        val shiftSuffix = shift.fold("")(sft => s", ${sft.emit}")
        s"${base.toOperand}, ${off.toOperand}$shiftSuffix"
      case None => s"${base.toOperand}"
    }
    s" str ${src.toOperand}, [$address]"
  }
}
// STRB: `strb src, [base, offset]`
case class STRB(src: Operand, base: Reg, offset: Operand) extends Instruction {
  override def emit: String =
    " strb " + src.toOperand + ", [" + base.toOperand + ", " + offset.toOperand + "]"
}
// STUR: always stores to the fixed offset #-4 relative to `base`
case class STUR(src: Operand, base: Reg) extends Instruction {
  override def emit: String =
    " stur " + src.toOperand + ", [" + base.toOperand + ", #-4]"
}
// RET: return from the current function
case class RET() extends Instruction {
  override def emit: String = " ret"
}
// EOF: emits an empty line
case class EOF() extends Instruction {
  override def emit: String = ""
}
// --------------------------
// Bool flags
// --------------------------
/** A condition code; `emit` yields the suffix used in conditional instructions. */
sealed trait Flag extends Instruction
case object NE extends Flag { override def emit: String = "ne" }
case object EQ extends Flag { override def emit: String = "eq" }
case object GT extends Flag { override def emit: String = "gt" }
case object LT extends Flag { override def emit: String = "lt" }
case object GE extends Flag { override def emit: String = "ge" }
case object LE extends Flag { override def emit: String = "le" }
case object VS extends Flag { override def emit: String = "vs" }
// Placeholder returned when no real condition could be produced; "failed" is not a
// condition code the instructions above can meaningfully use.
case object FAILEDFLAG extends Flag { override def emit: String = "failed" }
// ----------------------------------------------------
// General Instruction Classes
// ----------------------------------------------------
// Contains all classes which perform the generic movement
// (To allow for all allocated types to work)
/**
 * A pseudo-instruction that copies `src` into `dest`, choosing the concrete
 * instructions from where each side lives:
 *
 *  - register  <- register / immediate : MoveOps
 *  - register  <- spill slot           : Load from the source slot
 *  - spill slot <- register            : Store to the destination slot
 *  - spill slot <- spill slot          : Load the source slot into a temporary
 *                                        X register, then Store it to the
 *                                        destination slot
 *
 * NOTE(review): an immediate source with a spilled destination is not handled and
 * raises a MatchError - confirm that callers never request it.
 *
 * @param dest where the value is written
 * @param src  where the value is read from
 */
case class Move(dest: Allocated, src: Allocated)
               (using allocator: RegisterAllocator) extends Instruction {
  // Emits the expanded instruction sequence, one instruction per line
  override def emit: String = getInstrs().map(_.emit).mkString("\n")

  /** The concrete instruction sequence this move expands to. */
  def getInstrs(): List[Instruction] = {
    val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
    dest match {
      // Destination is a register: MOV or LDR
      case regDest: Reg => src match {
        // Move the Src operand to the Dest register
        case opSrc: (Reg | Immediate) => instructions += MoveOps(regDest, opSrc)
        // Load the Src slot into the Dest register
        case Spill(offsetSrc, baseSrc) => instructions += Load(regDest, baseSrc, Some(offsetSrc))
      }
      // Destination is a spill slot: STR
      case Spill(offsetDest, baseDest) => src match {
        case regSrc: Reg =>
          instructions += Store(regSrc, baseDest, Some(offsetDest))
        // TODO: Might need to add logic for tempReg Spill creation so that there's always a free
        // temp alloc to move things to
        // Slot-to-slot: read the SOURCE slot into a temp, then write it to the DESTINATION slot
        case Spill(offsetSrc, baseSrc) =>
          // Create a temp x bit-mode register
          val tempAlloc = allocator.allocateTemp(X_REG)
          tempAlloc match {
            case reg: Reg =>
              instructions += Load(reg, baseSrc, Some(offsetSrc))
              instructions += Store(reg, baseDest, Some(offsetDest))
          }
          allocator.freeTemp(tempAlloc) // Free it (location doesn't matter)
      }
    }
    instructions.toList
  }
}
/**
 * A pseudo-instruction that stores `reg1` to memory, choosing the concrete STR form
 * from the shape of the address:
 *
 *  - base register, no offset          : `str reg1, [base]`
 *  - spill slot, no offset             : `str reg1, [slotBase, slotOffset]`
 *  - base register, offset in -256..255: `str reg1, [base, offset]`
 *  - base register, offset outside that range: the offset is first moved into X17
 *    and used as a register offset
 *  - spill slot plus offset            : the slot's content is loaded into a
 *    temporary X register, which then serves as the base
 *
 * A temporary X register is always allocated on entry and freed on exit, whether or
 * not the chosen form uses it.
 */
case class Store(reg1: Reg, alloc2: Allocated, _imm3: Option[Immediate] = None)
                (using allocator: RegisterAllocator) extends Instruction {
  // Emits the expanded instruction sequence, one instruction per line
  override def emit: String = getInstrs().map(_.emit).mkString("\n")

  /** The concrete instruction sequence this store expands to. */
  def getInstrs(): List[Instruction] = {
    val scratch = allocator.allocateTemp(X_REG)
    val code: List[Instruction] = (alloc2, _imm3) match {
      case (target: Reg, None) =>
        List(STR(reg1, target))
      case (Spill(slot, framePtr), None) =>
        List(STR(reg1, framePtr, Some(slot)))
      case (target: Reg, Some(imm3)) => imm3 match {
        // Offset below -256: move it to X17 first
        case Immediate(below) if below.toInt < -256 =>
          List(MOV(X17, imm3), STR(reg1, target, Some(X17)))
        // Offset above 255: move it to X17 first
        case Immediate(above) if above.toInt > 255 =>
          List(MoveOps(X17, imm3), STR(reg1, target, Some(X17)))
        case _ =>
          List(STR(reg1, target, Some(imm3)))
      }
      case (Spill(slot, framePtr), Some(imm3)) =>
        List(Load(scratch, framePtr, Some(slot)), STR(reg1, scratch, Some(imm3)))
    }
    allocator.freeTemp(scratch)
    code
  }
}
/**
 * A pseudo-instruction that compares `alloc1` with `op2`.
 *
 * A register is compared directly. A spilled value is first loaded into a temporary
 * register: a W register when `op2` is a W-mode register or an immediate, an X
 * register otherwise. The temporary is freed again afterwards.
 *
 * @param _sxtw optional extend modifier forwarded to CMP
 */
case class Compare(alloc1: Allocated, op2: Operand, _sxtw: Option[SXTW] = None)
                  (using allocator: RegisterAllocator) extends Instruction {
  override def emit: String = getInstrs().map(_.emit).mkString("\n")

  /** The concrete instruction sequence this comparison expands to. */
  def getInstrs(): List[Instruction] = alloc1 match {
    case direct: Reg =>
      List(CMP(direct, op2, _sxtw))
    case Spill(slot, framePtr) =>
      // Decide the width of the temporary from the other operand
      val useNarrow = op2 match {
        case other: Reg => other.bitMode == W_REG
        case _: Immediate => true
        case _ => false
      }
      val scratch =
        if (useNarrow) allocator.allocateTemp(W_REG)
        else allocator.allocateTemp(X_REG)
      val code: List[Instruction] = List(
        LDR(scratch, framePtr, Some(slot)),
        CMP(scratch, op2, _sxtw)
      )
      allocator.freeTemp(scratch)
      code
  }
}
/**
 * A pseudo-instruction that loads into `dest`, choosing the concrete LDR form from
 * the shape of the address:
 *
 *  - base register with an immediate offset outside -256..255: the offset is first
 *    moved into X17 and used as a register offset
 *  - base register otherwise: a plain LDR with the given offset and shift
 *  - spilled base with an offset: the slot's content is loaded into a temporary
 *    X register, which then serves as the base
 *  - spilled base without an offset: the slot itself is loaded into `dest`
 */
case class Load(dest: Reg, base: Allocated, offset: Option[Operand] = None, shift: Option[Shift] = None)
               (using allocator: RegisterAllocator) extends Instruction {
  override def emit: String = getInstrs().map(_.emit).mkString("\n")

  /** The concrete instruction sequence this load expands to. */
  def getInstrs(): List[Instruction] = base match {
    case baseReg: Reg => offset match {
      // Offset below -256: move it to X17 first
      case Some(imm @ Immediate(below)) if below.toInt < -256 =>
        List(MOV(X17, imm), LDR(dest, baseReg, Some(X17), shift))
      // Offset above 255: move it to X17 first
      case Some(imm @ Immediate(above)) if above.toInt > 255 =>
        List(MoveOps(X17, imm), LDR(dest, baseReg, Some(X17), shift))
      // Otherwise the offset (if any) can be used as is
      case _ =>
        List(LDR(dest, baseReg, offset, shift))
    }
    case Spill(slot, framePtr) =>
      if (offset.isDefined) {
        // Two loads: first the spilled base, then the actual value
        val scratch = allocator.allocateTemp(X_REG)
        val code: List[Instruction] = List(
          LDR(scratch, framePtr, Some(slot)),
          LDR(dest, scratch, offset, shift)
        )
        allocator.freeTemp(scratch)
        code
      } else {
        // Load straight from the slot into the destination
        List(LDR(dest, framePtr, Some(slot)))
      }
  }
}

View File

@ -0,0 +1,399 @@
package wacc.backend
import scala.collection.mutable
// Names of the libc functions the generated code calls
val scanf: String = "scanf"
val exit: String = "exit"
val free: String = "free"
val puts: String = "puts"
val printf: String = "printf"
val fflush: String = "fflush"
val malloc: String = "malloc"
// Names of the standard helper labels emitted alongside the program
val _prints: String = "_prints"
val _printc: String = "_printc"
val _printi: String = "_printi"
val _printb: String = "_printb"
val _printp: String = "_printp"
val _println: String = "_println"
val _readi: String = "_readi"
val _readc: String = "_readc"
val _malloc: String = "_malloc"
val _errOutOfMemory: String = "_errOutOfMemory"
val _errOutOfBounds: String = "_errOutOfBounds"
val _errBadChar: String = "_errBadChar"
val _errOverflow: String = "_errOverflow"
val _errDivZero: String = "_errDivZero"
val _errNull: String = "_errNull"
// Array helpers are named after the suffix passed in: "_arrLoad" / "_arrStore" + size
def _arrLoad(size: String): String = "_arrLoad" + size
def _arrStore(size: String): String = "_arrStore" + size
val _freepair: String = "_freepair"
/**
* LabelGenerator provides unique labels used for branch targets
* We use a simple counter that is incremented on every request
* Also produces standard labels (for read, print, and any other call that may be needed)
*/
class LabelGenerator {
private var branchCounter = -1
private var stringCounter = -1
private def prologue(label: String, reg1: Reg, reg2: Reg): List[Instruction] = List (
Label(label),
Comment(s"push {${reg1.toOperand}${if (reg2 != XZR) s", ${reg2.toOperand}" else ""}}"),
STP(reg1, reg2, SP)
)
private def epilogue(reg1: Reg, reg2: Reg): List[Instruction] = List (
Comment(s"pop {${reg1.toOperand}${if (reg2 != XZR) s", ${reg2.toOperand}" else ""}}"),
LDP(reg1, reg2, SP),
RET()
)
// Maps each label name to its full instruction set
private val widgetMap: Map[String, List[Instruction]] = Map(
_prints -> printInstr('s'), _printc -> printInstr('c'), _printi -> printInstr('i'),
_printb -> printInstr('b'), _printp -> printInstr('p'), _println -> printInstr('l'),
_readi -> readInstr('i'), _readc -> readInstr('c'), _malloc -> mallocInstr(),
_errOutOfMemory -> errInstr('m'), _errOutOfBounds -> errInstr('b'),
_errBadChar -> errInstr('c'), _errOverflow -> errInstr('o'), _errDivZero -> errInstr('z'),
_errNull -> errInstr('n'), _arrLoad("1") -> arrNInstr("Load", '1'),
_arrLoad("4") -> arrNInstr("Load", '4'), _arrLoad("8") -> arrNInstr("Load", '8'),
_arrStore("1") -> arrNInstr("Store", '1'), _arrStore("4") -> arrNInstr("Store", '4'),
_arrStore("8") -> arrNInstr("Store", '8'), _freepair -> freePairInstr()
)
// Contains the unique standard labels that have been called through the program
private val widgetSet: mutable.Set[String] = mutable.Set.empty
// A map to hold all string literals (so that duplicates reuse the same label)
private val stringLiterals: mutable.Map[String, (String, String)] = mutable.Map.empty
def getStringLiterals(): List[(String, (String, String))] = stringLiterals.toSeq.sortBy {
// Sort by the number of the string label
_._2._2.stripPrefix(".L.str").toInt
}.toList
/**
* Escapes special characters in a string literal
*/
def escapeString(s: String): String = {
s.flatMap {
case '\n' => "\\n"
case '\t' => "\\t"
case '\r' => "\\r"
case '\"' => "\\\""
case '\\' => "\\\\"
case c => c.toString
}
}
/**
* Generates all of the standard labels, this depends on how populated the set is
*/
def genWidgets(): List[Instruction] = {
widgetSet.toSeq.sortBy(str => str).toList.flatMap(EOF() +: widgetMap.get(_).get)
}
def addWidget(label: String): Unit = {
if (widgetMap.contains(label)) {
widgetSet += label
label match {
// Add the errOutOfMemory and prints labels for a malloc
case s"$mallocLabel" if mallocLabel == _malloc =>
widgetSet += _errOutOfMemory
widgetSet += _prints
// Add the errNull and prints labels for a freepair
case s"$freepairLabel" if freepairLabel == _freepair =>
widgetSet += _errNull
widgetSet += _prints
// For overflow, divide-by-zero, or null errors, add the prints label
case s"$errLabel" if errLabel == _errOverflow || errLabel == _errDivZero || errLabel == _errNull =>
widgetSet += _prints
// Matches on the 3 arrLoad/arrStore standard labels and adds the _errOutOfBounds label to the set
case label if (label.startsWith(_arrLoad("")) || label.startsWith(_arrStore(""))) =>
widgetSet += _errOutOfBounds
case _ => ()
}
}
}
def freshLabel(str: String): String = {
// If the map doesn't contain the string, then generate a new label
if (!stringLiterals.contains(str)) {
stringCounter += 1
stringLiterals += (str -> (escapeString(str), s".L.str$stringCounter"))
}
// Return the label
stringLiterals.get(str).get._2
}
/**
 * Produces the next unused branch label (for while / if-else code).
 * Each call advances branchCounter and returns ".L<counter>".
 */
def freshLabel(): String = {
  branchCounter = branchCounter + 1
  ".L" + branchCounter
}
// ----------------------------------------------------
// Standard labels
// ----------------------------------------------------
/**
 * Builds the `_print<x>` standard routine selected by `_type`:
 * 'p' pointer, 'c' char, 'i' int, 's' string, 'b' boolean, 'l' println.
 * The result is the routine's format-string data followed by its code.
 * Any other selector produces only the common skeleton with no format data.
 */
def printInstr(_type: Char): List[Instruction] = {
  // println is spelled "ln"; every other routine is named after its selector
  val suffix = if (_type == 'l') "ln" else _type.toString
  val formatLabel = s".L._print${suffix}_str0"

  // Length word, label and text of a single format string stored under formatLabel
  def formatEntry(lengthWord: String, text: String): List[Instruction] = List(
    Directive(lengthWord, " "),
    Label(formatLabel),
    Directive(s"asciz \"$text\"", " ")
  )

  // Format-string data that precedes the routine
  val formatData: List[Instruction] = _type match {
    case 'p' => formatEntry("word 2", "%p")
    case 'c' => formatEntry("word 2", "%c")
    case 'i' => formatEntry("word 2", "%d")
    case 's' => formatEntry("word 4", "%.*s")
    case 'l' => formatEntry("word 0", "")
    case 'b' =>
      // Booleans need both spellings plus the format used to print them
      List(
        Directive("word 5", " "),
        Label(".L._printb_str0"),
        Directive("asciz \"false\"", " "),
        Comment("length of .L._printb_str1", ""),
        Directive("word 4", " "),
        Label(".L._printb_str1"),
        Directive("asciz \"true\"", " "),
        Comment("length of .L._printb_str2", ""),
        Directive("word 4", " "),
        Label(".L._printb_str2"),
        Directive("asciz \"%.*s\"", " ")
      )
    case _ => Nil
  }

  // Moves the value being printed into the argument registers
  val argumentSetup: List[Instruction] = _type match {
    case 'p' | 'c' | 'i' =>
      List(MoveOps(X1, X0))
    case 's' =>
      List(MoveOps(X2, X0), LDUR(W1, X0))
    case 'b' =>
      // Select the "false" or "true" string, then load its length
      List(
        CMP(W0, Imm0),
        new BCond(NE, ".L_printb0"),
        ADR(X2, ".L._printb_str0"),
        B(".L_printb1"),
        Label(".L_printb0"),
        ADR(X2, ".L._printb_str1"),
        Label(".L_printb1"),
        LDUR(W1, X2)
      )
    case _ => Nil
  }

  // Booleans print through their third string; everything else through str0
  val formatAddress = if (_type == 'b') ".L._printb_str2" else formatLabel
  // println goes through puts, every other routine through printf
  val outputCall = if (_type == 'l') new BL(puts) else new BL(printf)

  List[Instruction](Comment(s"length of $formatLabel", "")) ++
    formatData ++
    List[Instruction](Directive("align 4")) ++
    prologue(s"_print$suffix", LR, XZR) ++
    argumentSetup ++
    List[Instruction](
      ADR(X0, formatAddress),
      outputCall,
      MoveOps(X0, Imm0),
      new BL(fflush)
    ) ++
    epilogue(LR, XZR)
}
/**
 * Builds the `_read<x>` standard routine for the selector `_type`.
 * Format data is emitted for 'i' (int, "%d") and 'c' (char, " %c") only;
 * any other selector yields the routine skeleton without a format string.
 */
def readInstr(_type: Char): List[Instruction] = {
  val label = s".L._read${_type}_str0"
  // The routine is named after the label without its ".L." prefix and "_str0" suffix
  val routineName = s"_read${_type}"

  // Length word, label and scanf format for the supported selectors
  val formatData: List[Instruction] = _type match {
    case 'i' =>
      List(Directive("word 2", " "), Label(label), Directive("asciz \"%d\"", " "))
    case 'c' =>
      List(Directive("word 3", " "), Label(label), Directive("asciz \" %c\"", " "))
    case _ => Nil
  }

  List[Instruction](Comment(s"length of $label", "")) ++
    formatData ++
    List[Instruction](Directive("align 4")) ++
    prologue(routineName, X0, LR) ++
    List[Instruction](
      // scanf(format, sp)
      MoveOps(X1, SP),
      ADR(X0, label),
      new BL(scanf)
    ) ++
    epilogue(X0, LR)
}
/**
 * Builds the `_malloc` standard routine: calls malloc and branches to
 * the out-of-memory handler when X0 comes back as zero.
 */
def mallocInstr(): List[Instruction] = {
  val entry = prologue(_malloc, LR, XZR)
  val allocate = List[Instruction](
    new BL(malloc),
    CBZ(X0, _errOutOfMemory)
  )
  val leave = epilogue(LR, XZR)
  List.empty[Instruction] ++ entry ++ allocate ++ leave
}
/**
 * Builds the fatal-error routine selected by `errType`:
 * 'm' out of memory, 'b' array out of bounds, 'c' bad char, 'o' overflow,
 * 'z' division by zero, 'n' null pair. Every routine loads its message
 * into X0, reports it, then calls exit with -1. An unknown selector
 * produces a routine with empty label, length directive and message.
 */
def errInstr(errType: Char): List[Instruction] = {
  // Routine label, declared message length, and message text
  // NOTE(review): the 'c' text below is 49 bytes including the newline but is declared as "word 50" — confirm which is intended
  val (routine, lengthDirective, message) = errType match {
    case 'm' => (_errOutOfMemory, "word 27", "out of memory")
    case 'b' => (_errOutOfBounds, "word 42", "array index %d out of bounds")
    case 'c' => (_errBadChar, "word 50", "int %d is not ascii character 0-127")
    case 'o' => (_errOverflow, "word 52", "integer overflow or underflow occurred")
    case 'z' => (_errDivZero, "word 40", "division or modulo by zero")
    case 'n' => (_errNull, "word 45", "null pair dereferenced or freed")
    case _ => ("", "", "")
  }
  val messageLabel = s".L.${routine}_str0"

  // Message data followed by the routine entry, which loads the message address
  val header = List[Instruction](
    Comment(s"length of $messageLabel", ""),
    Directive(lengthDirective, " "),
    Label(messageLabel),
    Directive(s"asciz \"fatal error: ${message}\\n\"", " "),
    Directive("align 4"),
    Label(routine),
    ADR(X0, messageLabel)
  )

  // Messages containing %d go straight to printf; the rest use the prints routine
  val report: List[Instruction] = errType match {
    case 'b' | 'c' =>
      List(new BL(printf), MoveOps(X0, Imm0), new BL(fflush))
    case 'm' | 'o' | 'z' | 'n' =>
      List(new BL(_prints))
    case _ => Nil
  }

  // Shared tail: exit(-1)
  val terminate = List[Instruction](MoveOps(W0, Imm_1), new BL(exit))

  header ++ report ++ terminate
}
/**
 * Builds the bounds-checked array access routine `_arr<_type><n>`.
 *
 * The routine rejects a negative index and an index that is not below the
 * length loaded from the array pointer, branching to the out-of-bounds
 * handler in either case, and then performs the element access.
 *
 * @param _type "Load" or "Store"; any other value raises a MatchError
 * @param n     element size selector: '8', '4' or '1'; any other value
 *              emits the bounds check without an access instruction
 * @return the routine, starting with a comment describing its calling convention
 */
def arrNInstr(_type: String, n: Char): List[Instruction] = {
  val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
  // Describe the calling convention ahead of the routine. The comment was
  // previously constructed and then dropped, so it never reached the output.
  instructions += (_type match {
    case "Store" => Comment(s"Special calling convention: array ptr passed in X7, index in X17," +
      " value to store in X8, LR(W30) is used as general register")
    case "Load" => Comment("Special calling convention: array ptr passed in X7, index in X17," +
      " LR(W30) is used as general register, and return into X7")
  })
  instructions ++= prologue(s"_arr$_type$n", LR, XZR)
  instructions ++= List(
    // Reject negative indices
    CMP(W17, Imm0),
    Comment("this must be a 64-bit move so that it doesn't truncate if the move fails"),
    CSEL(X1, X17, X1, LT),
    new BCond(LT, _errOutOfBounds),
    // Reject indices at or beyond the length loaded through the array pointer
    LDUR(W30, X7),
    CMP(W17, W30),
    Comment("this must be a 64-bit move so that it doesn't truncate if the move fails"),
    CSEL(X1, X17, X1, GE),
    new BCond(GE, _errOutOfBounds)
  )
  // The access itself; the index is shifted by log2 of the element size
  _type match {
    case "Load" => n match {
      case '8' => instructions += LDR(X7, X7, Some(X17), Some(LSL(Imm3)))
      case '4' => instructions += LDR(W7, X7, Some(X17), Some(LSL(Imm2)))
      case '1' => instructions += LDRB(W7, X7, X17)
      case _ => ()
    }
    case "Store" => n match {
      case '8' => instructions += STR(X8, X7, Some(X17), Some(LSL(Imm3)))
      case '4' => instructions += STR(W8, X7, Some(X17), Some(LSL(Imm2)))
      case '1' => instructions += STRB(W8, X7, X17)
      case _ => ()
    }
  }
  instructions ++= epilogue(LR, XZR)
  instructions.toList
}
/**
 * Builds the `_freepair` standard routine: branches to the null-error
 * handler when X0 is zero, otherwise calls free.
 */
def freePairInstr(): List[Instruction] = {
  val entry = prologue(_freepair, LR, XZR)
  val release = List[Instruction](
    CBZ(X0, _errNull),
    new BL(free)
  )
  val leave = epilogue(LR, XZR)
  List.empty[Instruction] ++ entry ++ release ++ leave
}
}

View File

@ -0,0 +1,311 @@
package wacc.backend
import wacc.frontend.syntax.ast._
import wacc.frontend.semantic.environment._
import scala.collection.mutable
/**
* Generates assembly for a RHS
*
* Handles the four RHS forms: a plain expression (delegated to genExpr),
* `newpair`, an array literal, and a function call.
*
* @param rhs  the right-hand side to generate code for
* @param cGen code generator used to emit the push/pop sequences for funcRegList
* @return the generated instructions together with the register holding the result
*/
def genRhs(rhs: Rhs, cGen: CodeGenerator)
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator,
gFuncEnv: GlobalFuncEnv, funcRegList: List[(String, Allocated)] = List.empty):
(List[Instruction], Reg) = {
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
rhs match {
case expr: Expr => genExpr(expr)
case NewPair(_, fst, snd) =>
// Save argument registers
instructions ++= cGen.generatePushInstructions(funcRegList, callingFunc = true)
// Allocate 16 bytes for the pair and keep the pointer in X16
instructions ++= List(
MoveOps(W0, Imm16),
BL(_malloc),
MoveOps(X16, X0)
)
// Pop off argument registers
instructions ++= cGen.generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
// First component: evaluated, then stored at offset 0 through an X-width view of its register
val (fstInstr, fstAlloc) = genExpr(fst)
val xFst = PhysicalReg(fstAlloc.regNumber, X_REG, fstAlloc.pool)
instructions ++= fstInstr
instructions += STR(xFst, X16)
allocator.freeTemp(fstAlloc)
// Second component: same again, stored at offset 8
val (sndInstr, sndAlloc) = genExpr(snd)
val xSnd = PhysicalReg(sndAlloc.regNumber, X_REG, sndAlloc.pool)
instructions ++= sndInstr
instructions += STR(xSnd, X16, Some(Imm8))
allocator.freeTemp(sndAlloc)
// The pair pointer is handed back in a fresh X-width temp
val pairAlloc = allocator.allocateTemp(X_REG)
instructions += MoveOps(pairAlloc, X16)
(instructions.toList, pairAlloc)
case arr@ArrayLiter(_, elems) =>
// Save argument registers
instructions ++= cGen.generatePushInstructions(funcRegList, callingFunc = true)
// Allocate the first temp
val tempAlloc1 = allocator.allocateTemp(W_REG)
// Create the immediate that is to be used for mallocing to the result register
// The element size comes from the first element; an empty literal is sized as a pair literal
val typeSize = sizeof(exprType(elems.headOption.getOrElse(PairLiter(IGNORE)))(using gMainEnv))
// Total size: the elements plus 4 bytes for the stored length
val arrImm = Immediate((typeSize * elems.length) + 4)
val elemsImm = Immediate(elems.length)
// Add the preamble for adding elements to the array
instructions ++= List(
Comment(s"${elems.length} element array"),
MoveOps(W0, arrImm),
BL(_malloc),
MoveOps(X16, X0)
)
// Pop off argument registers
instructions ++= cGen.generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
// NOTE(review): the STUR below presumably writes the length just before the shifted pointer — confirm STUR's implicit offset
instructions ++= List(
Comment("array pointers are shifted forwards by 4 bytes (to account for size)"),
ADD(X16, X16, Imm4),
MoveOps(tempAlloc1, elemsImm),
STUR(tempAlloc1, X16)
)
allocator.freeTemp(tempAlloc1)
// Add each element to the memory location (depending on the type of the element)
instructions ++= elems.zipWithIndex.flatMap{(elem: Expr, index: Int) =>
// Allocate the temp allocation for both bit-modes as they are both needed
val (exprInstr, exprAlloc) = genExpr(elem)
// NOTE(review): xReg keeps whatever bit mode genExpr returned (unlike NewPair it is not rebuilt with X_REG) — confirm genExpr yields an X register for 8-byte elements
val (exprAllocX, exprAllocW) = exprAlloc match {
case xReg@PhysicalReg(regNumber, bitMode, pool) =>
(xReg, PhysicalReg(regNumber, W_REG, pool))
}
// Add the specific element type
// Byte offset of this element from the shifted array pointer
val indexImm = Immediate(index * typeSize)
val specificInstr = exprType(elem)(using gMainEnv) match {
case IntType => List(
STR(exprAllocW, X16, Some(indexImm))
)
case (CharType | BoolType) => List(
STRB(exprAllocW, X16, indexImm)
)
case _ => List(
STR(exprAllocX, X16, Some(indexImm))
)
}
allocator.freeTemp(exprAlloc)
exprInstr ++ specificInstr
}
// The final array is always in an x bit-mode register
val tempAlloc = allocator.allocateTemp(X_REG)
instructions += MoveOps(tempAlloc, X16)
(instructions.toList, tempAlloc)
case Call(_, id, args) =>
instructions += Comment("FUNCTION CALL")
// Store parameter args onto stack before moving values
instructions ++= cGen.generatePushInstructions(funcRegList, callingFunc = true)
// Add the needed stack space for extra arguments (args 9 onwards)
// The count of extra arguments is rounded up to an even number of 8-byte slots;
// with MAX_PARAM_REGS or fewer arguments the value is not positive and no space is reserved
val _addedStackSpace = ((args.size - MAX_PARAM_REGS + 1) & ~1) * EIGHT_BYTES
val addedStackSpace = Immediate(_addedStackSpace)
if (_addedStackSpace > 0) {
instructions += SUB(SP, SP, addedStackSpace)
}
// Retrieves each argument and then moves its content to the correct arg register or stack slot
instructions ++= args.zipWithIndex.flatMap{ case (arg, index) =>
val instrs: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
// Get the instructions and allocation for the argument
val (argInstr, argAlloc) = genExpr(arg)
// Create an allocation for a parameter, this will either be a register or memory on the stack
val paramAlloc = index match {
// The first MAX_PARAM_REGS arguments use the register numbered by their position, in the argument's bit mode
case paramReg if paramReg < MAX_PARAM_REGS => Register(s"$index", argAlloc.bitMode)
// Later arguments go to consecutive 8-byte slots above SP
case paramSpill =>
val _offset = (paramSpill - MAX_PARAM_REGS) * EIGHT_BYTES
val offset = Immediate(_offset)
Spill(offset, SP)
}
// Add the argument's instructions and then move the argument's temp to the corresponding parameter reg
// Restoration of registers occurs in the declaration and assignment statements
instrs ++= argInstr
instrs += Move(paramAlloc, argAlloc)
allocator.freeTemp(argAlloc)
instrs.toList
}
// Call the function
// The id of the function used to query the global function environment
val funcKey = id.id
// Replace # with an underscore for funcLabel
val funcLabel = id.id.replaceAll("#", "_")
instructions += BL(funcLabel)
// Offload all stack manipulation for extra arguments
if (_addedStackSpace > 0) {
instructions += ADD(SP, SP, addedStackSpace)
}
// Get the function's declared type to choose the width of the return register
val lhsType = gFuncEnv.lookup(funcKey).getOrElse(
throw new RuntimeException(s"Function $funcKey not found in global function environment, couldnt get lhs")
)._type
val returnReg = lhsType match {
case (IntType | CharType | BoolType) => W0
case _ => X0
}
// Registers will be popped after declaration / assignment to avoid overwriting X0
(instructions.toList, returnReg)
}
}
/**
 * Generates assembly for a LHS
 *
 * - An identifier yields its existing allocation and no instructions, which
 *   distinguishes assigning to a variable from reading it.
 * - A pair element yields the allocation of its inner LHS followed by a null
 *   check on that allocation; the access itself is left to the caller.
 * - Every other LHS is delegated to genExpr.
 *
 * @param lhs the left-hand side to generate code for
 * @return the generated instructions and the allocation they leave the LHS in
 */
def genLhs(lhs: Lhs)
    (using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator, funcRegList: List[(String, Allocated)] = List.empty):
    (List[Instruction], Allocated) = {
  val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
  lhs match {
    // Differs to the expr identifier parsing as it just releases the result
    // Needed to distinguish between assignments and just using the variable
    case Ident(_, id) =>
      val idAlloc = allocator.getAllocMapping(id)
      (instructions.toList, idAlloc)
    case PairElem(_, _, expr) =>
      val (exprInstr, exprAlloc) = genLhs(expr)
      // The inner LHS must be computed before it is tested: emitting the
      // comparison first would test exprAlloc before exprInstr has filled it.
      instructions ++= exprInstr
      instructions += Compare(exprAlloc, Imm0)
      instructions += BCond(EQ, _errNull)
      (instructions.toList, exprAlloc)
    // All other LHSs call genExpr and special cases are handled in the calling statement
    case expr: Expr => genExpr(expr)
  }
}
/**
 * Generates assembly for a LHS that is about to be stored into, with special
 * handling for array elements and pair elements.
 *
 * - Array element: code is generated for all but the last index, which is left
 *   for the caller. With a single index the array identifier's allocation is
 *   moved into the returned allocation by appending to `parentInstrs`.
 * - Pair element: the selectors are unfolded and every selector except the
 *   outermost is loaded through, innermost first. Each pointer is null-checked
 *   before it is dereferenced.
 * - Anything else is delegated to genLhs.
 *
 * @param lhs          the left-hand side being assigned to
 * @param parentInstrs the caller's instruction buffer; appended to in the
 *                     single-index array case only
 * @return the generated instructions and the allocation holding the base to
 *         store through
 */
def genLhsForUsing(lhs: Lhs, parentInstrs: mutable.ListBuffer[Instruction])
    (using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
    (List[Instruction], Allocated) = {
  lhs match {
    // In the case of an ArrayElem, you take the first n-1 indices
    // where n is the length of the indices list
    case ArrayElem(pos, id, indices) =>
      val shortenedIndices = indices.dropRight(1)
      val (instr, tempAlloc) = genLhs(ArrayElem(pos, id, shortenedIndices))
      // If the shortened query is nothing, then move the identifier's address
      // to the address that's been generated
      if (shortenedIndices.isEmpty) {
        val idAlloc = allocator.getAllocMapping(id.id)
        parentInstrs += Move(tempAlloc, idAlloc)
      }
      (instr, tempAlloc)
    // In the case of a PairElem, you need to unfold the PairElem to get the list of selectors
    // and then iteratively load the values from the pair in reverse order
    case _: PairElem =>
      val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
      // Branches to the null-error handler when `pointer` is zero
      def nullCheck(pointer: Allocated): Unit = {
        instructions += Compare(pointer, Imm0)
        instructions += BCond(EQ, _errNull)
      }
      val (selectorList, finalIdent) = unfoldPairElem(lhs)
      val (lhsInstr, lhsAlloc) = genLhs(finalIdent)
      // Compute the innermost LHS before anything tests or dereferences it
      instructions ++= lhsInstr
      var tempAlloc = lhsAlloc
      val reverse = selectorList.drop(1).reverse // Top case is handled in genStmt Assign
      if (reverse.isEmpty) {
        nullCheck(tempAlloc)
      }
      reverse.foreach { selector =>
        // Check the pointer before loading through it; checking afterwards
        // would let a null pair be dereferenced before the error is raised
        nullCheck(tempAlloc)
        val tempAlloc2 = allocator.allocateTemp(X_REG)
        val offset = selector match {
          case Fst => Some(Imm0)
          case Snd => Some(Imm8)
        }
        instructions += Load(tempAlloc2, tempAlloc, offset)
        allocator.freeTemp(tempAlloc)
        tempAlloc = tempAlloc2
      }
      // NOTE(review): after a nested load the returned pointer is not null-checked here — presumably the caller checks it; confirm
      (instructions.toList, tempAlloc)
    case _ => genLhs(lhs)
  }
}
/*
 * Strips the nested PairElem wrappers off a LHS.
 * Returns the selectors from outermost to innermost, together with the
 * first LHS found that is not a PairElem.
 */
def unfoldPairElem(lhs: Lhs): (List[PairSelector], Lhs) = {
  // Walks inwards, collecting selectors in reverse
  @scala.annotation.tailrec
  def peel(current: Lhs, collected: List[PairSelector]): (List[PairSelector], Lhs) =
    current match {
      case PairElem(_, selector, inner) => peel(inner, selector :: collected)
      case _ => (collected.reverse, current)
    }
  peel(lhs, Nil)
}
/**
 * Emits the instructions that write `rhsAlloc` into the location described
 * by `lhs` / `lhsAlloc`.
 *
 * - Array element: the last index goes to W17, the value to `reg8`, the
 *   array pointer to X7, and the size-specific arrStore routine is called.
 * - Pair element: the value is stored through `lhsAlloc`, at offset 8 for snd.
 * - Anything else: a plain move into `lhsAlloc`.
 */
def lhsMove(lhs: Lhs, lhsAlloc: Allocated, rhsAlloc: Reg, reg8: Reg)
    (using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
    List[Instruction] =
  lhs match {
    case arrElem@ArrayElem(_, _, indices) =>
      // Only the final index is evaluated here
      val (indexInstr, indexAlloc) = genExpr(indices.last)
      val storeRoutine = _arrStore(sizeof(lhsType(arrElem)).toString)
      val storeCall = List[Instruction](
        MoveOps(W17, indexAlloc),
        MoveOps(reg8, rhsAlloc),
        Move(X7, lhsAlloc),
        BL(storeRoutine)
      )
      allocator.freeTemp(indexAlloc)
      indexInstr ++ storeCall
    case PairElem(_, Fst, _) =>
      List(Store(rhsAlloc, lhsAlloc))
    case PairElem(_, Snd, _) =>
      List(Store(rhsAlloc, lhsAlloc, Some(Imm8)))
    // Otherwise, move the contents of the temp allocation to the identifier allocation
    case _ =>
      List(Move(lhsAlloc, rhsAlloc))
  }

View File

@ -0,0 +1,288 @@
package wacc.backend
import scala.collection.mutable
// ----------------------------------------------------
// Constants
// ----------------------------------------------------
// Number of registers in the callee pool, i.e. how many variables can live in a callee register
val MAX_CALLEE_REGS = CalleePool.available.length
// Number of call arguments passed in registers; later arguments are placed on the stack
val MAX_PARAM_REGS = 8
// Byte constants
val SIXTEEN_BYTES = 16
val EIGHT_BYTES = 8
// Bit-mode prefixes: a register is rendered as the prefix followed by its number
val X_REG = 'x'
val W_REG = 'w'
// Names of the registers that are referred to by name rather than by number
val FP_REG = "fp"
val SP_REG = "sp"
val LR_REG = "lr"
val XZR_REG = "xzr"
// Registers that are used through the program (particularly for standard labels)
val X0 = Register("0", X_REG)
val W0 = Register("0", W_REG)
val X1 = Register("1", X_REG)
val W1 = Register("1", W_REG)
val X2 = Register("2", X_REG)
val X7 = Register("7", X_REG)
val W7 = Register("7", W_REG)
val X8 = Register("8", X_REG)
val W8 = Register("8", W_REG)
val X16 = Register("16", X_REG)
val W16 = Register("16", W_REG)
val X17 = Register("17", X_REG)
val W17 = Register("17", W_REG)
// Special registers
val W30 = Register("30", W_REG)
// The blank bit mode makes these render as their bare name, with no x/w prefix
val FP = Register(FP_REG, ' ')
val SP = Register(SP_REG, ' ')
val LR = Register(LR_REG, ' ')
val XZR = Register(XZR_REG, ' ')
// Immediate values
val Imm_1 = Immediate(-1)
val Imm0 = Immediate(0)
val Imm1 = Immediate(1)
val Imm2 = Immediate(2)
val Imm3 = Immediate(3)
val Imm4 = Immediate(4)
val Imm8 = Immediate(8)
val Imm16 = Immediate(16)
// Represents an allocated value for code generation
sealed trait Allocated {
// Textual form of this allocation as it appears in an instruction
def toOperand: String
}
// Represents an operand value for parsing into instructions (Can only be registers or immediate values)
sealed trait Operand {
// Textual form of this operand as it appears in an instruction
def toOperand: String
}
// An operand that refers to a string label; it is rendered with ":lo12:" in front of the label
case class LabelOp(strLabel: String) extends Operand {
  override def toOperand: String = ":lo12:" + strLabel
}
// A physical register, usable both as an allocation and as an instruction operand
sealed trait Reg extends Allocated with Operand {
  // Register number, or the register's name for the named special registers
  def regNumber: String
  // 'x', 'w', or a blank for registers that are written without a prefix
  def bitMode: Char
  // Rendered as <bitMode><regNumber>; a blank bit mode renders the number/name alone
  override def toOperand: String =
    if (bitMode == ' ') regNumber
    else s"$bitMode$regNumber"
}
object Reg {
  // Bit modes a register may carry: x, w, or blank for the named special registers
  val allowedBitModes = Set(X_REG, W_REG, ' ')
  // Register identifiers a register may carry: "0".."31" plus the named special registers
  val allowedRegNums = ((0 to 31).map(_.toString) ++ Seq(FP_REG, SP_REG, LR_REG, XZR_REG)).toSet
  /**
   * Rejects a register whose number or bit mode is not in the allowed sets.
   * Throws IllegalArgumentException when either is unknown.
   */
  def validate(regNumber: String, bitMode: Char): Unit = {
    val knownNumber = allowedRegNums.contains(regNumber)
    val knownMode = allowedBitModes.contains(bitMode)
    if (!(knownNumber && knownMode)) {
      throw IllegalArgumentException("Cannot have a register of such number/bitmode")
    }
  }
}
// A concrete physical register allocated from a specific pool
// Construction fails with IllegalArgumentException when Reg.validate rejects the number or bit mode
case class PhysicalReg(regNumber: String, bitMode: Char, pool: RegisterPool) extends Reg {
Reg.validate(regNumber, bitMode)
}
// A register that is referred to directly rather than handed out by a pool
// Construction is validated in the same way as PhysicalReg
case class Register(regNumber: String, bitMode: Char) extends Reg {
Reg.validate(regNumber, bitMode)
}
// An immediate operand; the value is kept as text and rendered with a leading '#'
case class Immediate(value: String) extends Operand {
  override def toOperand: String = "#" + value
}
object Immediate {
  // Wraps an already-formatted value
  def apply(value: String): Immediate = new Immediate(value)
  // Formats a numeric value in decimal
  def apply(value: BigInt): Immediate = new Immediate(value.toString)
}
// A stack slot addressed as [basePointer, offset]
// The base defaults to FP; pass another register to address relative to it instead
case class Spill(offset: Immediate, basePointer: Reg = FP) extends Allocated {
  override def toOperand: String =
    "[" + basePointer.toOperand + ", " + offset.toOperand + "]"
}
// Tags a pool of registers that allocations can be drawn from
sealed trait RegisterPool {
  // Register numbers belonging to the pool, in allocation order
  def available: List[String]
  def name: String
}
// The callee-saved register pool: registers 19 to 28
case object CalleePool extends RegisterPool {
  override val available: List[String] = (19 to 28).map(_.toString).toList
  override val name: String = "callee"
}
// The temporary register pool: registers 8 to 15
case object TempPool extends RegisterPool {
  override val available: List[String] = (8 to 15).map(_.toString).toList
  override val name: String = "temp"
}
/**
 * Hands out physical registers (in x or w bit mode) and spill slots.
 *
 * Two independent free lists are maintained, one for the callee pool and one
 * for the temporary pool. Allocation functions take the bit mode (W_REG or
 * X_REG) of the register to create. The allocator also keeps the mapping
 * from identifiers to their allocations.
 */
class RegisterAllocator {
  // Register numbers currently free in each pool; the head is handed out next
  private var calleeFreeRegNums: List[String] = CalleePool.available
  private var tempFreeRegNums: List[String] = TempPool.available
  // Map from an identifier to its allocated value
  val identifierMap = mutable.Map[String, Allocated]()
  // Index of the most recently allocated spill slot, in 8-byte units relative to fp.
  // Starts at the negative of the callee register count, so the first slot
  // handed out is one unit below that.
  private var spillCounter: Int = -1 * MAX_CALLEE_REGS

  // Reserves the next spill slot and returns its byte offset
  private def allocateSpillSlot(): Immediate = {
    spillCounter -= 1
    Immediate(spillCounter * EIGHT_BYTES)
  }

  /**
   * Adds an allocation mapping for an identifier, replacing any previous one.
   * The mapping preserves the exact allocation (including its bit mode).
   */
  def addAllocMapping(ident: String, reg: Allocated): Unit = {
    identifierMap += (ident -> reg)
  }

  /**
   * Retrieves the allocated value for an identifier.
   * Throws RuntimeException when the identifier has no mapping.
   */
  def getAllocMapping(ident: String): Allocated =
    identifierMap.getOrElse(ident, throw new RuntimeException(s"Identifier $ident not allocated"))

  /**
   * Gets the number of identifiers that currently have a mapping
   * //TODO: May have to change this logic a bit later
   */
  def getNumVariables(): BigInt = identifierMap.size

  /**
   * Allocates a temporary register with the given bit mode (W_REG or X_REG).
   * The allocation policy assumes a temp register is always available, so an
   * empty pool is reported as an IllegalArgumentException. (Previously the
   * same exception type surfaced indirectly, from validating a placeholder
   * register, with a message about an invalid register number.)
   */
  def allocateTemp(bitMode: Char): PhysicalReg =
    tempFreeRegNums match {
      case regNum :: rest =>
        tempFreeRegNums = rest
        PhysicalReg(regNum, bitMode, TempPool)
      case Nil =>
        throw new IllegalArgumentException(
          "Temp register pool exhausted: no temporary register is available")
    }

  /**
   * Allocates a callee register with the given bit mode (W_REG or X_REG).
   * If no registers remain, a spill slot is allocated instead.
   */
  def allocateCallee(bitMode: Char): Allocated =
    calleeFreeRegNums match {
      case regNum :: rest =>
        calleeFreeRegNums = rest
        PhysicalReg(regNum, bitMode, CalleePool)
      case Nil =>
        Spill(allocateSpillSlot())
    }

  /**
   * Frees a previously allocated temp register by returning its number to the
   * front of the temp free list. Callee registers, spill slots and registers
   * that are already free are left untouched.
   */
  def freeTemp(alloc: Allocated): Unit = alloc match {
    case reg: PhysicalReg => reg.pool match {
      case TempPool =>
        if (!tempFreeRegNums.contains(reg.regNumber))
          tempFreeRegNums = reg.regNumber :: tempFreeRegNums
      // Can only free a temp reg
      case CalleePool => ()
    }
    case _ => ()
  }

  /**
   * Restores a temp if the pool isn't full: the register numbered one below
   * the current head of the free list (or 15 when the list is empty) is
   * returned to the pool, provided it is a temp register (8 or above).
   *
   * A register that is already in the free list is not added again, since a
   * duplicate entry would let the same register be handed out twice.
   */
  def restore(): Unit = {
    if (tempFreeRegNums.length < TempPool.available.length) {
      // This is possible due to policy of Temp pool always being available
      val restoredRegNum = tempFreeRegNums.headOption.getOrElse("16").toInt - 1
      val restored = s"$restoredRegNum"
      if (restoredRegNum >= 8 && !tempFreeRegNums.contains(restored)) {
        tempFreeRegNums = restored :: tempFreeRegNums
      }
    }
  }

  /**
   * Frees a previously allocated callee register by returning its number to
   * the front of the callee free list.
   * Spill slots are not reclaimed.
   * Temp registers cannot be freed here.
   */
  def freeCallee(alloc: Allocated): Unit = alloc match {
    case reg: PhysicalReg =>
      reg.pool match {
        // Can only free a callee reg
        case TempPool => ()
        case CalleePool =>
          if (!calleeFreeRegNums.contains(reg.regNumber))
            calleeFreeRegNums = reg.regNumber :: calleeFreeRegNums
      }
    case _ => // Other allocations are not reclaimed
  }

  /**
   * Resets the allocator's register pools: every callee and temp register
   * becomes free again.
   * NOTE(review): the spill counter and the identifier mappings are left as
   * they are, although this method was previously documented as resetting
   * them — confirm which behaviour callers expect.
   */
  def reset(): Unit = {
    calleeFreeRegNums = CalleePool.available
    tempFreeRegNums = TempPool.available
  }

  /**
   * Captures the allocator's state so it can be restored later: both free
   * lists, the spill counter, and a copy of the identifier mappings.
   */
  def saveState(): (List[String], List[String], Int, mutable.Map[String, Allocated]) = {
    (calleeFreeRegNums, tempFreeRegNums, spillCounter, identifierMap.clone())
  }

  /**
   * Restores a state captured by saveState, replacing the current free lists,
   * spill counter and identifier mappings.
   */
  def restoreState(state: (List[String], List[String], Int, mutable.Map[String, Allocated])): Unit = {
    val (savedCalleeFreeRegNums, savedTempFreeRegNums, savedSpillCounter, savedIdentifierMap) = state
    calleeFreeRegNums = savedCalleeFreeRegNums
    tempFreeRegNums = savedTempFreeRegNums
    spillCounter = savedSpillCounter
    identifierMap.clear()
    identifierMap ++= savedIdentifierMap
  }
}

View File

@ -0,0 +1,75 @@
package wacc.backend
import wacc.frontend.syntax.ast._
import wacc.frontend.semantic.environment._
// Dummy value passed as the first constructor argument of AST nodes created
// during code generation (e.g. the PairLiter that stands in for the element of an empty array literal)
val IGNORE = (-1, -1)
/**
 * Returns the size of a type: 1 for chars and booleans, 4 for ints,
 * and 8 for everything else.
 */
def sizeof(_type: Type): Int = _type match {
  case CharType | BoolType => 1
  case IntType => 4
  case _ => 8
}
/**
 * Determines the type of an expression.
 * Identifiers are resolved through the given environment and fall back to
 * AnyType when unknown; an if-expression takes the type of its then-branch.
 */
def exprType(expr: Expr)(using gMainEnv: EnvTrait): Type = expr match {
  case _: (IntLiter | Neg | Len | Ord | Mul | Div | Mod | Add | Sub) => IntType
  case _: (BoolLiter | Not | Gt | Geq | Lt | Leq | Eq | Neq | And | Or) => BoolType
  case _: (CharLiter | Chr) => CharType
  case _: StrLiter => StrType
  // The null literal is typed as a PairPlaceholder
  case _: PairLiter => PairPlaceholder
  case Ident(_, id) => gMainEnv.lookupType(id).getOrElse(AnyType)
  case ArrayElem(_, id, indices) =>
    // Each index strips one array layer off the identifier's type;
    // indexing into anything that is not an array gives AnyType
    indices.foldLeft(exprType(id)) { (current, _) =>
      current match {
        case ArrayType(inner) => inner
        case _ => AnyType
      }
    }
  case pairElem: PairElem => lhsType(pairElem)
  case IfExpr(_, _, thenBranch, _) => exprType(thenBranch)
}
/**
 * Determines the type of a RHS.
 * Expressions, array literals and pairs are typed through gMainEnv;
 * a call takes the type recorded for the function in gFuncEnv.
 */
def rhsType(rhs: Rhs)
    (using gMainEnv: EnvTrait, gFuncEnv: EnvTrait): Type = rhs match {
  case expr: Expr => exprType(expr)(using gMainEnv)
  case ArrayLiter(_, elems) =>
    // The element type comes from the first element; an empty literal is typed as if it held a pair literal
    val representative = elems.headOption.getOrElse(PairLiter(IGNORE))
    ArrayType(exprType(representative)(using gMainEnv))
  case NewPair(_, fst, snd) =>
    val fstType = exprType(fst)(using gMainEnv)
    val sndType = exprType(snd)(using gMainEnv)
    PairType(fstType, sndType)
  case Call(_, ident, _) =>
    gFuncEnv.lookup(ident.id).get._type
}
/**
 * Determines the type of a LHS.
 * Identifiers and array elements are typed as expressions. A pair element
 * takes the selected component of its inner LHS's pair type, and is AnyType
 * when the inner LHS is not typed as a pair (e.g. the null case).
 */
def lhsType(lhs: Lhs)
    (using gMainEnv: EnvTrait): Type = lhs match {
  case identifier: (Ident | ArrayElem) => exprType(identifier)
  case PairElem(_, selector, subLhs) =>
    // Pick the component named by the selector, if the inner type is a pair
    val component = (selector, lhsType(subLhs)) match {
      case (Fst, PairType(first, _)) => Some(first)
      case (Snd, PairType(_, second)) => Some(second)
      case _ => None
    }
    component match {
      case Some(PairPlaceholder) => PairPlaceholder
      case Some(t: Type) => t
      // This is for any null case, so would be named an AnyType may have TODO a change on this
      case _ => AnyType
    }
}

View File

@ -0,0 +1,267 @@
package wacc.extension
import wacc.backend._
import scala.collection.mutable
class Peephole {
def optimize(instructions: List[Instruction]): List[Instruction] = {
Function.chain(List(removeRedundantLoadandStores, removeRedundantMoves))(instructions)
}
private def removeRedundantLoadandStores(instructions: List[Instruction]): List[Instruction] = {
val substitutions = mutable.ListBuffer[(Allocated | Operand, Reg, Int)]()
val removalList = mutable.ListBuffer[Int]()
instructions.zipWithIndex.foreach({
// Keep track of stored values, and index of instruction, and remove in situations where the register may be changed
case (STP(dest, src, base, _), index) =>
substitutions += ((dest, base, index))
substitutions += ((src, base, index))
case (STR(dest, base, _, _), index) => substitutions += ((dest, base, index))
case (STRB(dest, base, _), index) => substitutions += ((dest, base, index))
case (STUR(dest, base), index) => substitutions += ((dest, base, index))
// Given a load, check if the register is still stored in the substitutions, if so then the load-store is redundant
case (LDP(dest, src, base), index) =>
substitutions.find { case (operand, reg, _) => operand == dest && reg == base } match {
case Some((_, _, fstIndex)) =>
substitutions.find { case (operand, reg, _) => operand == src && reg == base } match {
case Some((_, _, sndIndex)) =>
removalList += sndIndex; removalList += index
case None =>
}
case None =>
}
case (LDR(dest, base, _, _), index) =>
substitutions.find { case (operand, reg, _) => operand == dest && reg == base } match {
case Some((_, _, fstIndex)) =>
removalList += fstIndex; removalList += index
case None =>
}
case (LDRB(dest, base, _), index) =>
substitutions.find { case (operand, reg, _) => operand == dest && reg == base } match {
case Some((_, _, fstIndex)) =>
removalList += fstIndex; removalList += index
case None =>
}
case (LDUR(dest, base), index) =>
substitutions.find { case (operand, reg, _) => operand == dest && reg == base } match {
case Some((_, _, fstIndex)) =>
removalList += fstIndex; removalList += index
case None =>
}
// Consider all cases where the register is modified, and remove the stored value
case (MoveOps(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
case (Move(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
case (ADD(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
case (SUB(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
case (SMULL(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
case (SDIV(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
case (ADDS(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
case (SUBS(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
case (MSUB(dest, _, _, _), _) => substitutions.filterInPlace(_._1 != dest)
case (NEGS(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
case (EOR(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
case (ADRP(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
case (ADR(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
case (CSET(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
case (CSEL(dest, _, _, _), _) => substitutions.filterInPlace(_._1 != dest)
// Clear the substitutions on branches given any register may be manipulated in the callee
case (B(_), _) => substitutions.clear()
case (BL(_, _), _) => substitutions.clear()
case (BCond(_, _, _), _) => substitutions.clear()
case (CBZ(_, _), _) => substitutions.clear()
case (Label(_), _) => substitutions.clear()
case _ =>
})
instructions.zipWithIndex.collect {
case (instr, index) if !removalList.contains(index) =>
instr
}
}
// Move forwarding: when a move `dest <- src` is seen, it is remembered as an active substitution.
// If a later instruction reads `dest` while the substitution is still active, that read is rewritten
// to use `src` instead and the original move is scheduled for removal.
// Pass 1 records which instruction can consume which move; pass 2 rewrites the operands; a final
// filter drops every instruction whose index ended up in the removal list.
private def removeRedundantMoves(instructions: List[Instruction]): List[Instruction] = {
// NOTE(review): `allocator` is never referenced by name in this method; presumably it is picked up
// as a context parameter by the instruction constructors used below - confirm.
given allocator: RegisterAllocator = new RegisterAllocator()
val substitutions = mutable.ListBuffer[((Allocated | Operand, Allocated | Operand), Int)]() // Active moves: ((dest, src), index of the move)
val substitutedLines = mutable.ListBuffer[((Allocated | Operand, Allocated | Operand), (Int, Int))]() // Planned rewrites: ((dest, src), (index of the consuming instruction, index of the move))
val removalList = mutable.ListBuffer[Int]() // Indexes of instructions to drop from the final output
// Despite its name this helper does not append to `substitutions` (callers do that themselves). It:
//  1. checks each operand in `src` against the active moves and, on a hit, plans a rewrite for the
//     instruction at `index` and schedules the matched move for removal;
//  2. invalidates every active move that mentions `dest`, since `dest` is overwritten at `index`.
def addSubstitution(dest: Allocated, src: List[Allocated | Operand], index: Int): Unit = {
src.foreach { source =>
// Only the first active move whose destination equals this operand is considered
substitutions.find(_._1._1 == source) match {
case Some(((srcDest, sub), subIndex)) =>
substitutedLines += (((srcDest, sub), (index, subIndex)))
removalList += subIndex
case None =>
}
}
// removes the dest from the substitutions list as both destination and source if it is changed
substitutions.filterInPlace(_._1._1 != dest)
substitutions.filterInPlace(_._1._2 != dest)
}
// Applies the rewrites planned for the instruction at `index` to its operand list `src` and
// returns the (possibly rewritten) operands in the same order.
// If operandBool is true, non-register replacements are allowed, but only in operand positions
// other than the first; register replacements are applied at any position.
// `dest` is not used in the body.
def useSubstitution(dest: Reg, src: List[Operand], index: Int, operandBool: Boolean): List[Operand] = {
val indexes = substitutedLines.filter(_._2._1 == index).map(_._2)
val subs = substitutedLines.filter(_._2._1 == index).map(_._1)
val mutableSrc = mutable.ListBuffer(src*)
subs.foreach {
case (subDest, sub: Reg) =>
mutableSrc.zipWithIndex.foreach { case (source, idx) =>
if (subDest == source) mutableSrc(idx) = sub
}
case (subDest, sub: Operand) if operandBool =>
mutableSrc.zipWithIndex.foreach { case (source, idx) =>
if (idx != 0 && subDest == source) mutableSrc(idx) = sub
}
case _ =>
}
// Undo planned rewrites that had no effect: where an operand is unchanged, the paired record is
// dropped and its move is taken off the removal list so the move is kept.
// NOTE(review): operands are paired with the planned rewrites purely by position (k-th operand
// with k-th record for this index) - confirm this pairing is the intended one.
mutableSrc.zip(src).zip(indexes).foreach { case ((a, b), i) =>
if (a == b) {
substitutedLines.filterInPlace(_._2 != i)
removalList.filterInPlace(_ != i._2)
}
}
mutableSrc.toList
}
// Initial pass, where substitutions are added and instructions are checked to see whether they can be substituted
instructions.zipWithIndex.foreach {
// A move onto itself does nothing, so it is always removed
case (MoveOps(dest, src), index) if dest == src =>
removalList += index
case (Move(dest, src), index) if dest == src =>
removalList += index
case (MoveOps(dest, src), index) =>
src match {
// Immediate needs to be explicitly dealt with to avoid overflow substitution
case Immediate(value) =>
val valInt = value.toInt
// Immediates outside [0, 0xFFFF) are never forwarded; the move only invalidates entries mentioning dest
if (valInt < 0 || valInt >= 0xFFFF) {
substitutions.filterInPlace(_._1._1 != dest)
substitutions.filterInPlace(_._1._2 != dest)
} else {
addSubstitution(dest, List(src), index)
substitutions += (((dest, src), index))
}
case _ =>
addSubstitution(dest, List(src), index)
substitutions += (((dest, src), index))
}
// A plain Move may consume an earlier substitution but is not itself recorded as a new one
case (Move(dest, src), index) => addSubstitution(dest, List(src), index)
// MSUB and CSEL are not rewritten; all active substitutions are dropped when one is met
case (MSUB(dest, src1, src2, src3), index) => substitutions.clear()
case (CSEL(dest, src1, src2, src3), index) => substitutions.clear()
case (ADD(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
case (ADDS(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
case (SUB(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
case (SUBS(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
case (SMULL(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
case (SDIV(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
case (EOR(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
case (NEGS(dest, src), index) => addSubstitution(dest, List(src), index)
// Although these instructions do not use operand registers, addSubstitution is still used to clear substitutions based on destination
case (ADRP(dest, src), index) => addSubstitution(dest, List(), index)
case (ADR(dest, src), index) => addSubstitution(dest, List(), index)
case (CSET(dest, src), index) => addSubstitution(dest, List(), index)
// The following instructions leave the active substitutions untouched
case (Compare(_, _, _), _) =>
case (CMP(_, _, _), _) =>
case (TST(_, _), _) =>
case (Comment(_, _), _) =>
case (Directive(_, _), _) =>
case (EOF(), _) =>
case (RET(), _) =>
// Clear substitutions on any instruction that hasn't been explicitly handled to maintain consistency.
case x => substitutions.clear()
}
// Second pass, where instructions are substituted given they were added in the first pass.
// The final catch-all case makes this `collect` total, so `substituted` keeps the original
// length and its indexes still line up with `removalList`.
val substituted = instructions.zipWithIndex.collect {
case (MoveOps(dest, src), index) if substitutedLines.exists(_._2._1 == index) =>
// Only the first planned rewrite for this index is used
val indexes = substitutedLines.find(_._2._1 == index).get._2
val sub = substitutedLines.find(_._2._1 == index).get._1._2
sub match {
case x : Operand => MoveOps(dest, x)
case _ =>
// Replacement cannot be used here: cancel the rewrite and keep the originating move
substitutedLines.filterInPlace(_._2 != indexes)
removalList.filterInPlace(_ != indexes._2)
MoveOps(dest, src)
}
case (Move(dest, src), index) if substitutedLines.exists(_._2._1 == index) =>
val indexes = substitutedLines.find(_._2._1 == index).get._2
val sub = substitutedLines.find(_._2._1 == index).get._1._2
sub match {
case x : Allocated => Move(dest, x)
case x : Operand =>
dest match {
// An operand replacement turns the instruction into a MoveOps, which needs a Reg destination
case y : Reg => MoveOps(y, x)
case _ =>
// Otherwise cancel the rewrite and keep the originating move
substitutedLines.filterInPlace(_._2 != indexes)
removalList.filterInPlace(_ != indexes._2)
Move(dest, src)
}
}
// We call useSubstitution for generic cases on all operands, and reconstruct instruction with substituted operands.
case (ADD(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
val srcs = useSubstitution(dest, List(src1, src2), index, true)
ADD(dest, srcs(0), srcs(1))
case (ADDS(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
val srcs = useSubstitution(dest, List(src1, src2), index, true)
ADDS(dest, srcs(0), srcs(1))
case (SUB(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
val srcs = useSubstitution(dest, List(src1, src2), index, true)
SUB(dest, srcs(0), srcs(1))
case (SUBS(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
val srcs = useSubstitution(dest, List(src1, src2), index, true)
SUBS(dest, srcs(0), srcs(1))
// SMULL, SDIV and NEGS pass operandBool = false: only register replacements are applied
case (SMULL(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
val srcs = useSubstitution(dest, List(src1, src2), index, false)
SMULL(dest, srcs(0), srcs(1))
case (SDIV(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
val srcs = useSubstitution(dest, List(src1, src2), index, false)
SDIV(dest, srcs(0), srcs(1))
case (EOR(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
val srcs = useSubstitution(dest, List(src1, src2), index, true)
EOR(dest, srcs(0), srcs(1))
case (NEGS(dest, src), index) if substitutedLines.exists(_._2._1 == index) =>
val srcs = useSubstitution(dest, List(src), index, false)
NEGS(dest, srcs(0))
// Everything else is passed through unchanged
case (instr, index) =>
instr
}
// Finally, return the instructions with the redundant moves removed.
substituted.zipWithIndex.collect {
case (instr, index) if !removalList.contains(index) =>
instr
}
}
}

View File

@ -0,0 +1,185 @@
package wacc.frontend.semantic
import wacc.frontend.syntax.ast._
import wacc.backend.exprType
import scala.collection.mutable
object environment {
// Common lookup interface shared by the variable, function and main-body environments.
trait EnvTrait {
// Resolves an identifier to its symbol, or None when the environment has no entry for it
def lookup(name: String): Option[Symbol]
// Resolves an identifier to the declared type of its symbol, or None when it is unknown
def lookupType(id: String): Option[Type]
}
// A lexical scope for variable (and parameter) declarations, optionally chained to an enclosing scope
class Env(val parent: Option[Env]) extends EnvTrait {
  // Symbols declared directly in this scope, keyed by their source-level identifier
  val vars: mutable.Map[String, VarSymbol] = mutable.Map.empty

  // Looks in this scope first and only then walks outwards through the parent chain
  override def lookup(name: String): Option[VarSymbol] =
    vars.get(name).orElse(parent.flatMap(_.lookup(name)))

  // Type of the variable visible under `id`; a trailing "#<digits>" suffix is stripped first
  // because scopes are keyed by the raw identifier rather than the renamed one
  override def lookupType(id: String): Option[Type] = {
    val rawName = id.replaceAll("#\\d+$", "")
    lookup(rawName).map(_._type)
  }

  // Declares `sym` in this scope, or records a redeclaration error if the identifier is
  // already declared in this same scope (outer scopes may be shadowed freely)
  def add(sym: VarSymbol)(using errors: mutable.ListBuffer[SemanticError]): Unit = {
    if (vars.contains(sym.id)) {
      errors += RedeclarationError(sym.pos, sym.id, inFunction = false)(using this)
    } else {
      vars(sym.id) = sym
    }
  }
}
// Global environment for functions. Functions may be overloaded: several definitions can share a
// base identifier as long as their parameter signatures do not match one another.
class GlobalFuncEnv extends EnvTrait {
  // A map from the base identifier (not renamed) to the known overloads, each stored as
  // (parameter types, unique identifier)
  val funcPoolMap: mutable.Map[String, mutable.ListBuffer[(List[Type], String)]] = mutable.Map.empty
  // A map from the unique identifier to its func symbol. After a clashing redeclaration it also
  // holds one entry keyed by the base identifier, which RedeclarationError reads via `lookup`.
  val funcs: mutable.Map[String, FuncSymbol] = mutable.Map.empty

  // Looks up the function and may return its symbol
  override def lookup(id: String): Option[FuncSymbol] = funcs.get(id)

  // Return type of the function registered under `id`, if any
  override def lookupType(id: String): Option[Type] = funcs.get(id).map(_._type)

  /**
   * Performs lookup for accessing the correct unique symbol for calls.
   * Does a rudimentary type analysis to match to the correct function (we don't actually check
   * if the typing is correct).
   * Records an UndefinedFuncError when nothing matches and an AmbiguousFuncCallError when more
   * than one overload matches; in both cases None is returned.
   */
  def callLookup(id: String, args: List[Expr], pos: (Int, Int), localEnv: Env)
                (using errors: mutable.ListBuffer[SemanticError]): Option[FuncSymbol] = {
    // One type per argument, in call order
    val argSignature = args.map(exprType(_)(using localEnv))
    // Number of matching overloads and the unique id of the last one that matched
    val (matchedFunctions, uniqueId) = funcPoolMap.get(id) match {
      case None => (0, "")
      case Some(funcSignatures) => countMatches(funcSignatures.toList, argSignature)
    }
    matchedFunctions match {
      case 0 =>
        // Undefined function
        errors += UndefinedFuncError(pos, id)
        None
      // Exactly one candidate: resolve it through its unique id
      case 1 => funcs.get(uniqueId)
      case _ =>
        // More than one candidate means the call is ambiguous
        errors += AmbiguousFuncCallError(pos, id)
        None
    }
  }

  /**
   * Registers `funcSymbol` unless an already-registered overload of the same base identifier
   * matches its parameter signature, in which case a RedeclarationError is recorded instead.
   */
  def add(funcSymbol: FuncSymbol)(using errors: mutable.ListBuffer[SemanticError]): Unit = {
    val FuncSymbol(id, uniqueId, _, paramSymbs, pos) = funcSymbol
    val paramTypes = getParamTypes(paramSymbs)
    val funcSignatures = funcPoolMap.getOrElse(id, mutable.ListBuffer.empty)
    // `clashingId` is the unique id of the last registered overload that matched (if any)
    val (matchedFunctions, clashingId) = countMatches(funcSignatures.toList, paramTypes)
    matchedFunctions match {
      // The signature is new for this identifier: register the overload
      case 0 =>
        funcPoolMap.getOrElseUpdate(id, mutable.ListBuffer.empty) += ((paramTypes, uniqueId))
        funcs += (uniqueId -> funcSymbol)
      // Illegal declaration
      case _ =>
        errors += RedeclarationError(pos, id, inFunction = true, paramTypes)(using this)
        // RedeclarationError reports "previously declared on line ..." from lookup(id), so the
        // entry stored under the base identifier must be the EARLIER clashing declaration, not
        // the redeclaration itself. The redeclaring symbol is only a fallback that keeps
        // lookup(id) defined. The message is rendered lazily, so with several clashes on one
        // identifier the most recent entry is the one reported.
        funcs += (id -> funcs.getOrElse(clashingId, funcSymbol))
    }
  }

  /**
   * Finds the number of matches that an argSignature makes with the funcSignatures that are
   * known to the program thus far.
   * Returns the number of matches and the last seen unique id ("" when nothing matched).
   */
  def countMatches(funcSignatures: List[(List[Type], String)], argSignature: List[Type]): (Int, String) = {
    val matchingIds = funcSignatures.collect {
      // Only candidates of the same arity are considered; every argument type must then
      // typecast to the parameter type at the same position
      case (paramSignature, uId)
        if paramSignature.size == argSignature.size &&
          argSignature.zip(paramSignature).foldLeft(true) { case (acc, (argType, paramType)) =>
            (argType typecast Some(paramType)) && acc
          } => uId
    }
    (matchingIds.size, matchingIds.lastOption.getOrElse(""))
  }

  /**
   * Create a list of the types that the parameters have, in declaration order
   */
  private def getParamTypes(paramSymbs: List[VarSymbol]): List[Type] = paramSymbs.map(_._type)
}
// Flat, program-wide table of declared variables, keyed by their unique (renamed) identifier
class GlobalMainEnv extends EnvTrait {
  val vars: mutable.Map[String, VarSymbol] = mutable.Map.empty

  // Symbol registered under the unique identifier `id`, if any
  def lookup(id: String): Option[Symbol] = vars.get(id)

  // Declared type of the variable registered under the unique identifier `id`, if any
  override def lookupType(id: String): Option[Type] = vars.get(id).map(_._type)

  // Registers `_var` under its unique name, replacing any previous entry with that name.
  // The `errors` buffer is accepted but never written to here.
  def add(_var: VarSymbol)(using errors: mutable.ListBuffer[SemanticError] = mutable.ListBuffer.empty): Unit = {
    vars(_var.uniqueName) = _var
  }
}
}

View File

@ -0,0 +1,274 @@
package wacc.frontend.semantic
import scala.collection.mutable
import wacc.frontend.semantic.environment._
import wacc.frontend.syntax.ast._
// Renaming pass: gives every function, parameter and variable a unique name of the form
// `name#N`, resolves identifier uses against the enclosing scopes, and collects scope errors
// (undefined / redeclared identifiers, undefined / ambiguous calls) along the way.
class Renamer {
  // Suffix source for generated names; incremented before every use
  private var counter: Int = 0

  // Produces the next fresh name for `name`
  private def generateUnique(name: String): String = {
    counter += 1
    s"$name#$counter"
  }

  // Renames the entire program, returning a new WProgram with uniquely renamed identifiers,
  // the errors found so far, and the two global environments that were populated.
  // Functions are processed before the statements of the main body.
  def renameProgram(program: WProgram):
    (WProgram, mutable.ListBuffer[SemanticError], GlobalFuncEnv, GlobalMainEnv) = {
    // Shared, mutable state for the whole traversal
    given errors: mutable.ListBuffer[SemanticError] = mutable.ListBuffer.empty
    given gFuncEnv: GlobalFuncEnv = new GlobalFuncEnv
    given gMainEnv: GlobalMainEnv = new GlobalMainEnv
    // Outermost scope of the main body
    given localEnv: Env = new Env(None)
    given inFunction: Boolean = false

    val renamedFuncs = renameFuncs(program.funcs)
    val renamedStats = program.stats.map(renameStmt)
    (WProgram(program.pos, program.imports, renamedFuncs, renamedStats), errors, gFuncEnv, gMainEnv)
  }

  // Renames all function definitions. Every signature is registered before any body is
  // renamed, so a body may call a function that is defined later in the file.
  def renameFuncs(funcs: List[Func])
                 (using errors: mutable.ListBuffer[SemanticError],
                  gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): List[Func] = {
    // Step 1: rename each signature, register it, and build the scope holding its parameters
    val registered: List[(FuncSymbol, List[Param], List[Stmt], Env)] = funcs.map { f =>
      val uniqueFuncName = generateUnique(f.id.id)
      val renamedParams = f.args.map { p =>
        val uniqueParamName = generateUnique(p.id.id)
        val paramSymbol = VarSymbol(p.id.id, uniqueParamName, p.t, p.pos)
        (paramSymbol, Param(p.pos, p.t, Ident(p.id.pos, uniqueParamName)))
      }
      val (paramSymbols, newParams) = renamedParams.unzip
      val funcSym = FuncSymbol(f.id.id, uniqueFuncName, f.t, paramSymbols, f.id.pos)
      gFuncEnv.add(funcSym)
      // Parameters get a scope of their own, with no parent
      val funcEnv = new Env(None)
      paramSymbols.foreach(funcEnv.add)
      (funcSym, newParams, f.body, funcEnv)
    }
    // Step 2: rename each body inside a scope nested in its parameter scope
    registered.map { case (funcSym, newParams, body, funcEnv) =>
      given localEnv: Env = new Env(Some(funcEnv))
      given inFunction: Boolean = true
      val newIdent = Ident(funcSym.pos, funcSym.uniqueName)
      val newBody = body.map(stmt => renameStmt(stmt))
      Func(funcSym.pos, funcSym._type, newIdent, newParams, newBody)
    }
  }

  // Renames `stmts` inside a fresh scope whose parent is `parent`
  private def renameInChildScope(stmts: List[Stmt], parent: Env)
    (using inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
     gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): List[Stmt] = {
    given localEnv: Env = new Env(Some(parent))
    stmts.map(renameStmt)
  }

  // Renames `expr` inside a fresh scope whose parent is `parent`
  private def renameExprInChildScope(expr: Expr, parent: Env)
    (using inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
     gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): Expr = {
    given localEnv: Env = new Env(Some(parent))
    renameExpr(expr)
  }

  // Replaces an identifier by its unique name; an unresolved identifier is reported and
  // returned unchanged
  private def renameIdent(ident: Ident)
    (using localEnv: Env, errors: mutable.ListBuffer[SemanticError]): Ident =
    localEnv.lookup(ident.id) match {
      case Some(varSym) => Ident(ident.pos, varSym.uniqueName)
      case None =>
        // Undeclared variable
        errors += UndefinedValError(ident.pos, ident.id)
        ident
    }

  // Renames the array identifier first and then every index expression
  private def renameArrayElem(elem: ArrayElem)
    (using localEnv: Env, inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
     gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): ArrayElem = {
    val ArrayElem(pos, id, indices) = elem
    val newId = renameIdent(id)
    ArrayElem(pos, newId, indices.map(renameExpr))
  }

  // Rename the elements within statements while checking the validity of identifiers
  def renameStmt(stmt: Stmt)
                (using localEnv: Env, inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
                 gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): Stmt = {
    stmt match {
      case Skip(pos) => Skip(pos)
      case Assign(pos, lhs, rhs) => Assign(pos, renameLhs(lhs), renameRhs(rhs))
      case Read(pos, lhs) => Read(pos, renameLhs(lhs))
      case Free(pos, expr) => Free(pos, renameExpr(expr))
      case Return(pos, expr) => Return(pos, renameExpr(expr))
      case Exit(pos, expr) => Exit(pos, renameExpr(expr))
      case Print(pos, expr) => Print(pos, renameExpr(expr))
      case Println(pos, expr) => Println(pos, renameExpr(expr))
      case Declare(pos, t, ident, rhs) =>
        // The right-hand side is renamed before the new name exists, so it cannot refer to
        // the variable being declared
        val newRhs = renameRhs(rhs)
        val uniqueName = generateUnique(ident.id)
        val varSym = VarSymbol(ident.id, uniqueName, t, ident.pos)
        // Register the new symbol in the current scope and in the global variable table
        localEnv.add(varSym)
        gMainEnv.add(varSym)
        Declare(pos, t, Ident(ident.pos, uniqueName), newRhs)
      case If(pos, cond, thenCase, elseCase) =>
        val newCond = renameExpr(cond)
        // Each branch gets a scope of its own
        val newThen = renameInChildScope(thenCase, localEnv)
        val newElse = renameInChildScope(elseCase, localEnv)
        If(pos, newCond, newThen, newElse)
      case While(pos, cond, body) =>
        val newCond = renameExpr(cond)
        While(pos, newCond, renameInChildScope(body, localEnv))
      case Block(pos, stmts) =>
        Block(pos, renameInChildScope(stmts, localEnv))
    }
  }

  // Renames all expressions which may contain identifiers and leaves literals untouched
  def renameExpr(expr: Expr)
                (using localEnv: Env, inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
                 gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): Expr = {
    expr match {
      // Unary operators
      case Not(x) => Not(renameExpr(x))
      case Neg(x) => Neg(renameExpr(x))
      case Len(x) => Len(renameExpr(x))
      case Ord(x) => Ord(renameExpr(x))
      case Chr(x) => Chr(renameExpr(x))
      // Binary operators
      case Mul(x, y) => Mul(renameExpr(x), renameExpr(y))
      case Div(x, y) => Div(renameExpr(x), renameExpr(y))
      case Mod(x, y) => Mod(renameExpr(x), renameExpr(y))
      case Add(x, y) => Add(renameExpr(x), renameExpr(y))
      case Sub(x, y) => Sub(renameExpr(x), renameExpr(y))
      case Gt(x, y) => Gt(renameExpr(x), renameExpr(y))
      case Geq(x, y) => Geq(renameExpr(x), renameExpr(y))
      case Lt(x, y) => Lt(renameExpr(x), renameExpr(y))
      case Leq(x, y) => Leq(renameExpr(x), renameExpr(y))
      case Eq(x, y) => Eq(renameExpr(x), renameExpr(y))
      case Neq(x, y) => Neq(renameExpr(x), renameExpr(y))
      case And(x, y) => And(renameExpr(x), renameExpr(y))
      case Or(x, y) => Or(renameExpr(x), renameExpr(y))
      case ident: Ident => renameIdent(ident)
      case elem: ArrayElem => renameArrayElem(elem)
      case IfExpr(pos, cond, thenBranch, elseBranch) =>
        val newCond = renameExpr(cond)
        // Each branch is renamed in a scope of its own
        val newThen = renameExprInChildScope(thenBranch, localEnv)
        val newElse = renameExprInChildScope(elseBranch, localEnv)
        IfExpr(pos, newCond, newThen, newElse)
      // The rest are literals so there is nothing to rename
      case rest => rest
    }
  }

  // Renames the right-hand side of a declaration or assignment
  def renameRhs(rhs: Rhs)
               (using localEnv: Env, inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
                gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): Rhs = {
    rhs match {
      case ArrayLiter(pos, elems) => ArrayLiter(pos, elems.map(renameExpr))
      case NewPair(pos, fst, snd) => NewPair(pos, renameExpr(fst), renameExpr(snd))
      case Call(pos, id, args) =>
        val newArgs = args.map(renameExpr)
        // An unresolved call (the lookup has already recorded the error) still gets a fresh
        // unique name so that a renamed node can be produced
        val uniqueName = gFuncEnv.callLookup(id.id, newArgs, pos, localEnv)
          .map(_.uniqueName)
          .getOrElse(generateUnique(id.id))
        Call(pos, Ident(id.pos, uniqueName), newArgs)
      case PairElem(pos, selector, lhs) => PairElem(pos, selector, renameLhs(lhs))
      case expr: Expr => renameExpr(expr)
    }
  }

  // Renames an assignment target
  def renameLhs(lhs: Lhs)
               (using localEnv: Env, inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
                gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): Lhs = {
    lhs match {
      case ident: Ident => renameIdent(ident)
      case elem: ArrayElem => renameArrayElem(elem)
      case PairElem(pos, selector, inner) => PairElem(pos, selector, renameLhs(inner))
    }
  }
}

View File

@ -0,0 +1,18 @@
package wacc.frontend.semantic
import wacc.frontend.semantic.environment.GlobalFuncEnv
import wacc.frontend.semantic.environment.GlobalMainEnv
import wacc.frontend.syntax.ast.WProgram
// Runs the semantic phase: renaming (scope analysis) followed by type checking.
// Note the orientation of the result: Left carries the renamed program together with both
// global environments on success, Right carries the errors when any were found.
def analyse(program: WProgram):
  Either[(WProgram, GlobalFuncEnv, GlobalMainEnv), List[SemanticError]] = {
  val (renamedAST, scopeErrors, funcEnv, mainEnv) = new Renamer().renameProgram(program)
  // Type checking appends to the same error buffer that renaming filled
  val (checkedProgram, allErrors) =
    typeChecking.validateAST(renamedAST, funcEnv.funcs.toMap)(using mainEnv, scopeErrors)
  if (allErrors.isEmpty) Left((checkedProgram, funcEnv, mainEnv))
  else Right(allErrors)
}

View File

@ -0,0 +1,156 @@
package wacc.frontend.semantic
import scala.io.Source
import wacc.frontend.semantic.environment._
import wacc.frontend.syntax.ast.Type
// NOTE(review): presumably the process exit status used when semantic analysis fails - the
// use site is not in this file, confirm against the caller.
val SEMANTIC_ERROR_CODE = 200
// Base type of every error reported by scope analysis and type checking.
sealed trait SemanticError {
// (line, column) the error is attached to
def getPosition: (Int, Int)
// Renders the full report; `fileName` is printed in the header and `path` is the file that the
// offending source line is read from
def getMessage(fileName: String, path: String): String
}
// ====================================================
// ERRORS FOUND DURING SCOPE ANALYSIS
// ====================================================
// A call whose name and arguments match no known function
case class UndefinedFuncError(pos: (Int, Int), id: String) extends SemanticError {
  override def getPosition: (Int, Int) = pos

  // Header with file and position, the undefined name, then the offending source line
  override def getMessage(fileName: String, path: String): String = {
    val (row, col) = pos
    val sourceLine = getLineFromFile(path, row).getOrElse("<Unknown Line>").strip()
    s"Undefined function error in $fileName ($row,$col):\n function $id is undefined:\n$sourceLine\n"
  }
}
// A call that matches more than one known function
// TODO: May add on the number of matches
case class AmbiguousFuncCallError(pos: (Int, Int), id: String) extends SemanticError {
  override def getPosition: (Int, Int) = pos

  // Header with file and position, the ambiguous name, then the offending source line
  override def getMessage(fileName: String, path: String): String = {
    val (row, col) = pos
    val sourceLine = getLineFromFile(path, row).getOrElse("<Unknown Line>").strip()
    s"Ambiguous function call in $fileName ($row,$col):\n function call for $id is ambiguous:\n$sourceLine\n"
  }
}
// Use of a variable that is not declared in any visible scope. The scope the use occurred in is
// captured so that the report can list the variables that were visible there.
case class UndefinedValError(pos: (Int, Int), id: String)(using env: Env)
  extends SemanticError {
  override def getPosition: (Int, Int) = pos

  override def getMessage(fileName: String, path: String): String = {
    val (row, col) = pos
    val sourceLine = getLineFromFile(path, row).getOrElse("<Unknown Line>").strip()
    val inScope = unpackEnv(env)
    // The listing (and the source line) is only shown when the rendered listing is longer
    // than three characters
    val scopeSection =
      if (inScope.size > 3) s" relevant in-scope variables include:\n$inScope\n | $sourceLine\n"
      else " there are no relevant in-scope variables\n"
    s"Scope error in $fileName ($row,$col):\n variable $id has not been declared in this scope:\n$scopeSection"
  }

  // One line per visible declaration: its type, its name and the line it was declared on
  def unpackEnv(e: Env): String = {
    val rows = for (sym <- gatherVars(e))
      yield s" ${sym._type.getType} ${sym.id} (declared on line ${sym.pos._1})"
    rows.mkString("\n")
  }

  // Symbols of `e` followed by those of each enclosing scope, innermost first
  def gatherVars(e: Env): List[VarSymbol] = {
    val own = e.vars.values.toList
    e.parent.fold(own)(outer => own ++ gatherVars(outer))
  }
}
// A second declaration of a variable within one scope (inFunction = false) or of a function
// with an already-used parameter signature (inFunction = true). The environment is kept so that
// the earlier declaration can be looked up when the message is rendered.
case class RedeclarationError(pos: (Int, Int), id: String, inFunction: Boolean, paramTypes: List[Type] = Nil)
  (using env: EnvTrait) extends SemanticError {
  override def getPosition: (Int, Int) = pos

  override def getMessage(fileName: String, path: String): String = {
    val (row, col) = pos
    val sourceLine = getLineFromFile(path, row).getOrElse("<Unknown Line>").strip()
    // Everything up to the point where the previous declaration's line number is inserted
    val description =
      if (inFunction) {
        s"Function redefinition error in $fileName ($row,$col):\n" +
          s" illegal redefinition of function $id with parameters (${paramTypes.map(_.getType).mkString(", ")})\n" +
          " previously declared on "
      } else {
        s"Scope error in $fileName ($row,$col):\n" +
          s" illegal redeclaration of variable $id\n" +
          " previously declared (in this scope) on "
      }
    // Relies on `id` being resolvable in the captured environment
    val previousLine = env.lookup(id).get.pos._1
    s"${description}line $previousLine\n | $sourceLine\n"
  }
}
// ====================================================
// ERRORS FOUND DURING TYPE CHECKING
// ====================================================
// Type mismatch; `details` is rendered as (unexpected, expected)
case class InvalidTypeError(pos: (Int, Int), details: (String, String)) extends SemanticError {
  override def getPosition: (Int, Int) = pos

  override def getMessage(fileName: String, path: String): String = {
    val (row, col) = pos
    val (unexpected, expected) = details
    val sourceLine = getLineFromFile(path, row).getOrElse("<Unknown Line>").strip()
    s"Type error in $fileName ($row,$col):\n unexpected $unexpected\n expected $expected\n | $sourceLine\n"
  }
}
// Type error reported with a free-form explanation; only the second component of `details`
// appears in the message
case class PairAssignmentError(pos: (Int, Int), details: (String, String)) extends SemanticError {
  override def getPosition: (Int, Int) = pos

  override def getMessage(fileName: String, path: String): String = {
    val (row, col) = pos
    val explanation = details._2
    val sourceLine = getLineFromFile(path, row).getOrElse("<Unknown Line>").strip()
    s"Type error in $fileName ($row,$col):\n $explanation\n | $sourceLine\n"
  }
}
// Condition of the wrong type; `details` is rendered as (unexpected type, explanation)
case class InvalidConditionError(pos: (Int, Int), details: (String, String)) extends SemanticError {
  override def getPosition: (Int, Int) = pos

  override def getMessage(fileName: String, path: String): String = {
    val (row, col) = pos
    val (unexpected, explanation) = details
    val sourceLine = getLineFromFile(path, row).getOrElse("<Unknown Line>").strip()
    s"Condition error in $fileName ($row,$col):\n unexpected $unexpected\n $explanation\n | $sourceLine\n"
  }
}
// Invalid function call; only the second component of `details` appears in the message
case class FunctionCallError(pos: (Int, Int), details: (String, String)) extends SemanticError {
  override def getPosition: (Int, Int) = pos

  override def getMessage(fileName: String, path: String): String = {
    val (row, col) = pos
    val explanation = details._2
    val sourceLine = getLineFromFile(path, row).getOrElse("<Unknown Line>").strip()
    s"Function call error in $fileName ($row,$col):\n $explanation\n | $sourceLine\n"
  }
}
// Generic type error; `details` is rendered as (unexpected type, explanation)
case class GeneralTypeError(pos: (Int, Int), details: (String, String)) extends SemanticError {
  override def getPosition: (Int, Int) = pos

  override def getMessage(fileName: String, path: String): String = {
    val (row, col) = pos
    val (unexpected, explanation) = details
    val sourceLine = getLineFromFile(path, row).getOrElse("<Unknown Line>").strip()
    s"Type error in $fileName ($row,$col):\n unexpected $unexpected\n $explanation\n | $sourceLine\n"
  }
}
// Reads line `lineNumber` (1-based) of the file at `path`.
// Returns None when the file has fewer lines than requested or cannot be read at all.
private def getLineFromFile(path: String, lineNumber: Int): Option[String] = {
  try {
    val source = Source.fromFile(path)
    // Close the handle even when reading fails part-way (e.g. on a decoding error);
    // previously an exception raised while reading skipped `close()` and leaked it
    try source.getLines().drop(lineNumber - 1).nextOption()
    finally source.close()
  } catch {
    case _: Exception => None // Best effort: an unreadable file just means no line to show
  }
}

View File

@ -0,0 +1,20 @@
package wacc.frontend.semantic
import wacc.frontend.syntax.ast._
// Base trait for a resolved symbol (a variable, a parameter or a function).
sealed trait Symbol {
// Identifier as written in the source program
val id: String
// Declared type; for a function symbol this is its return type
val _type: Type
// (line, column) of the declaration
val pos: (Int, Int)
}
// A symbol representing a variable (or function parameter).
// `id` is the identifier as written in the source; `uniqueName` holds the new
// (alpha-converted) name that the renamer generated for it.
case class VarSymbol(id: String, uniqueName: String, _type: Type, pos: (Int, Int))
extends Symbol
// A symbol representing a function.
// `id` is the identifier as written in the source; `uniqueName` holds the new
// (alpha-converted) name. `_type` is the return type and `params` lists the parameter
// symbols in declaration order.
case class FuncSymbol(id: String, uniqueName: String, _type: Type,
params: List[VarSymbol], pos: (Int, Int)) extends Symbol

View File

@ -0,0 +1,293 @@
package wacc.frontend.semantic
import wacc.frontend.semantic.environment.GlobalMainEnv
import wacc.frontend.syntax.ast._
import scala.collection.mutable
object typeChecking {
// Builds an error with the supplied constructor and appends it to the shared error buffer.
// `offending` and `message` are handed to the constructor as its `details` pair.
private def addSemanticError(pos: (Int, Int), offending: String, message: String,
                             errConstructor: ((Int, Int), (String, String)) => SemanticError)
                            (using errors: mutable.ListBuffer[SemanticError]): Unit = {
  val error = errConstructor(pos, (offending, message))
  errors += error
}
// Type-checks every function and then the main body of `program`.
// Returns the program unchanged together with all errors accumulated in `errors` (including
// those already present on entry), sorted by source position.
def validateAST(program: WProgram, fMap: Map[String, FuncSymbol])
               (using globalMainEnv: GlobalMainEnv, errors: mutable.ListBuffer[SemanticError]):
  (WProgram, List[SemanticError]) = {
  // Make the function table available to the checks below as a context value
  given funcMap: Map[String, FuncSymbol] = fMap
  program.funcs.foreach(validateFunc)
  program.stats.foreach(validateStmt)
  val orderedErrors = errors.toList.sortBy(_.getPosition)
  (program, orderedErrors)
}
// Type-checks the body of one function, checking each statement against the function's
// declared return type and parameter list
def validateFunc(func: Func)
                (using funcMap: Map[String, FuncSymbol], globalMainEnv: GlobalMainEnv,
                 errors: mutable.ListBuffer[SemanticError]): Unit = func match {
  case Func(_, retType, _, params, stmts) =>
    stmts.foreach(validateStmt(_)(using funcMap, Some(retType), params))
}
// Type-checks a single statement (recursing into nested statement lists) and records every
// violation in `errors`. `funcReturnType` is None when the statement is not inside a function.
def validateStmt(stmt: Stmt)
                (using funcMap: Map[String, FuncSymbol], funcReturnType: Option[Type] = None,
                 paramList: List[Param] = Nil, globalMainEnv: GlobalMainEnv,
                 errors: mutable.ListBuffer[SemanticError]): Unit = {
  // Text used for a type in error messages; a type that could not be determined shows as AnyType
  def shown(typeOpt: Option[Type]): String = s"${typeOpt.getOrElse(AnyType).getType}"

  stmt match {
    case Skip(_) => ()

    // Printing accepts any type; the expression is still checked for internal errors
    case Print(pos, expr) => findExprType(pos, expr)
    case Println(pos, expr) => findExprType(pos, expr)

    case Declare(pos, declType, _, rhs) =>
      findRhsType(pos, rhs).foreach { rhsType =>
        // A pair placeholder is accepted for any declared type
        val compatible = rhsType == PairPlaceholder || (rhsType typecast Some(declType))
        if (!compatible) {
          addSemanticError(pos, s"${declType.getType}", s"${rhsType.getType}", InvalidTypeError.apply)
        }
      }

    case Assign(pos, lhs, rhs) =>
      (findLhsType(pos, lhs), findRhsType(pos, rhs)) match {
        case (Some(PairPlaceholder), Some(PairPlaceholder)) =>
          addSemanticError(pos, "", "Cannot assign unknown types (both PairPlaceholder)",
            PairAssignmentError.apply)
        case (Some(lhsType), Some(rhsType)) =>
          // Neither side has a known type
          if (lhsType == AnyType && rhsType == AnyType) {
            addSemanticError(pos, s"${lhsType.getType}", "Cannot assign AnyType to AnyType", GeneralTypeError.apply)
          }
          // A placeholder on one side is compatible with any pair type on the other
          val lhsIsPartialPair = lhsType == PairPlaceholder && rhsType.isInstanceOf[PairType]
          val rhsIsPartialPair = rhsType == PairPlaceholder && lhsType.isInstanceOf[PairType]
          val partialPair = lhsIsPartialPair || rhsIsPartialPair
          if (!partialPair && !(rhsType typecast Some(lhsType))) {
            addSemanticError(pos, s"${lhsType.getType}", s"${rhsType.getType}", InvalidTypeError.apply)
          }
        // A side whose type could not be determined is not reported again here
        case _ => ()
      }

    case Read(pos, lhs) =>
      findLhsType(pos, lhs) match {
        case Some(IntType) | Some(CharType) => ()
        case other =>
          addSemanticError(pos, shown(other),
            "Invalid type for Read (must be int or char)", GeneralTypeError.apply)
      }

    case Free(pos, expr) =>
      findExprType(pos, expr) match {
        case Some(ArrayType(_)) | Some(PairType(_, _)) => ()
        case other =>
          addSemanticError(pos, shown(other),
            "Invalid type cannot be freed (must be a pair or array)", GeneralTypeError.apply)
      }

    case Return(pos, expr) =>
      if (funcReturnType.isEmpty) {
        addSemanticError(pos, "", "Return statement outside of function", PairAssignmentError.apply)
      }
      findExprType(pos, expr).foreach { actual =>
        if (!(actual typecast funcReturnType)) {
          addSemanticError(pos,
            s"unexpected ${actual.getType}\n expected ${funcReturnType.getOrElse(AnyType).getType}",
            "Invalid return type", PairAssignmentError.apply)
        }
      }

    case Exit(pos, expr) =>
      findExprType(pos, expr) match {
        case Some(IntType) => ()
        case other =>
          addSemanticError(pos, shown(other),
            "Exit expression must be an int", GeneralTypeError.apply)
      }

    case Block(_, blockStmts) => blockStmts.foreach(validateStmt)

    // The branches are only checked when the condition itself is a bool
    case If(pos, cond, thenCase, elseCase) =>
      findExprType(pos, cond) match {
        case Some(BoolType) =>
          thenCase.foreach(validateStmt)
          elseCase.foreach(validateStmt)
        case other =>
          addSemanticError(pos, shown(other),
            "Invalid condition type in If (must be bool)", InvalidConditionError.apply)
      }

    // The body is only checked when the condition itself is a bool
    case While(pos, cond, body) =>
      findExprType(pos, cond) match {
        case Some(BoolType) => body.foreach(validateStmt)
        case other =>
          addSemanticError(pos, shown(other),
            "Invalid condition type in While (must be bool)", InvalidConditionError.apply)
      }
  }
}
/**
 * Infers the static type of an expression, recording a semantic error for
 * each ill-typed sub-expression found on the way.
 *
 * @param pos  position attached to any error reported for this expression
 * @param expr expression to analyse
 * @return the inferred type, or None when the expression could not be typed
 */
def findExprType(pos: (Int, Int), expr: Expr)
    (using params: List[Param] = Nil, globalMainEnv: GlobalMainEnv,
     errors: mutable.ListBuffer[SemanticError]): Option[Type] = {

  // Result type of a unary operator, provided its operand has an accepted type.
  def unary(operand: Expr, accepted: List[Type], result: Type): Option[Type] =
    if unaryOpTypeCheck(pos, operand, accepted) then Some(result) else None

  // Result type of a binary operator, provided both operands are acceptable.
  def binary(left: Expr, right: Expr, accepted: List[Type], result: Type): Option[Type] =
    if binOpTypeCheck(pos, left, right, accepted) then Some(result) else None

  def arithmetic(left: Expr, right: Expr): Option[Type] =
    binary(left, right, List(IntType), IntType)
  def ordering(left: Expr, right: Expr): Option[Type] =
    binary(left, right, List(IntType, CharType), BoolType)
  def equality(left: Expr, right: Expr): Option[Type] =
    binary(left, right, List(AnyType), BoolType)
  def logical(left: Expr, right: Expr): Option[Type] =
    binary(left, right, List(BoolType), BoolType)

  expr match {
    // Variables in the main environment take priority over function parameters.
    case Ident(_, id) =>
      globalMainEnv.lookupType(id).orElse(params.find(_.id.id == id).map(_.t))
    case IntLiter(_, _)  => Some(IntType)
    case StrLiter(_, _)  => Some(StrType)
    case BoolLiter(_, _) => Some(BoolType)
    case CharLiter(_, _) => Some(CharType)
    case PairLiter(_)    => Some(PairPlaceholder)
    case Not(x) => unary(x, List(BoolType), BoolType)
    case Neg(x) => unary(x, List(IntType), IntType)
    case Len(x) => unary(x, List(ArrayType(AnyType)), IntType)
    case Ord(x) => unary(x, List(CharType), IntType)
    case Chr(x) => unary(x, List(IntType), CharType)
    case Div(x, y) => arithmetic(x, y)
    case Mul(x, y) => arithmetic(x, y)
    case Mod(x, y) => arithmetic(x, y)
    case Add(x, y) => arithmetic(x, y)
    case Sub(x, y) => arithmetic(x, y)
    case Geq(x, y) => ordering(x, y)
    case Gt(x, y)  => ordering(x, y)
    case Lt(x, y)  => ordering(x, y)
    case Leq(x, y) => ordering(x, y)
    case Eq(x, y)  => equality(x, y)
    case Neq(x, y) => equality(x, y)
    case And(x, y) => logical(x, y)
    case Or(x, y)  => logical(x, y)
    // The type of `fst x` / `snd x` is the selected element's type, so the
    // whole pair-elem (not just its inner l-value) must be resolved.
    case pairElem: PairElem => findLhsType(pos, pairElem)
    case ArrayElem(_, ident, exprs) => arrayAccessType(pos, ident, exprs)
    case IfExpr(ifPos, cond, thenBranch, elseBranch) =>
      // The condition must be a bool, mirroring the check done for `if` statements.
      findExprType(ifPos, cond) match {
        case Some(BoolType) => ()
        case condType =>
          addSemanticError(ifPos, s"${condType.getOrElse(AnyType).getType}",
            "Invalid condition type in If-Expr (must be bool)", InvalidConditionError.apply)
      }
      val thenType = findExprType(ifPos, thenBranch)
      val elseType = findExprType(ifPos, elseBranch)
      // Either branch type may be the one the other can be cast to.
      if thenType.exists(_ typecast elseType) then thenType
      else if elseType.exists(_ typecast thenType) then elseType
      else {
        addSemanticError(ifPos,
          s"${thenType.getOrElse(AnyType).getType}, ${elseType.getOrElse(AnyType).getType}",
          "Incompatible types in If-Expr branches", GeneralTypeError.apply)
        None
      }
  }
}
/**
 * Checks that the operand of a unary operator has one of the required types.
 * Reports an InvalidTypeError at `pos` when it does not.
 *
 * @return true when the operand is typable and accepted
 */
def unaryOpTypeCheck(pos: (Int, Int), x: Expr, requiredTypes: List[Type])
    (using params: List[Param], globalMainEnv: GlobalMainEnv,
     errors: mutable.ListBuffer[SemanticError]): Boolean = {
  val operandType = findExprType(pos, x)
  // `typecast` is false for None, so an untypable operand is rejected here too.
  val accepted = operandType.isDefined && requiredTypes.exists(_ typecast operandType)
  if !accepted then
    addSemanticError(pos, s"${operandType.getOrElse(AnyType).getType} ", "Invalid type in unary operator", InvalidTypeError.apply)
  accepted
}
/**
 * Checks the operands of a binary operator: both must be typable, the left
 * one must be accepted by `requiredTypes` (or `requiredTypes` must be exactly
 * `List(AnyType)`), and the left type must be castable to the right type.
 * Reports a GeneralTypeError at `pos` otherwise.
 *
 * @return true when the operand pair is well-typed
 */
def binOpTypeCheck(pos: (Int, Int), x: Expr, y: Expr, requiredTypes: List[Type])
    (using params: List[Param], globalMainEnv: GlobalMainEnv,
     errors: mutable.ListBuffer[SemanticError]): Boolean = {
  // Both operands are always analysed so that errors inside each are reported.
  val leftType = findExprType(pos, x)
  val rightType = findExprType(pos, y)
  val acceptsAnything = requiredTypes == List(AnyType)
  val wellTyped = leftType.exists { left =>
    rightType.isDefined &&
      (requiredTypes.exists(_ typecast leftType) || acceptsAnything) &&
      (left typecast rightType)
  }
  if !wellTyped then
    addSemanticError(pos, s"got ${leftType.getOrElse(AnyType).getType}, ${rightType.getOrElse(AnyType).getType}",
      s"Invalid types in binary operator", GeneralTypeError.apply)
  wellTyped
}
/**
 * Infers the type of an assignment right-hand side (array literal, newpair,
 * pair element, function call or plain expression), reporting semantic
 * errors for ill-typed parts.
 *
 * @param pos position attached to errors that are not call-specific
 * @param rhs right-hand side to analyse
 * @return the inferred type, or None when it cannot be determined
 */
def findRhsType(pos: (Int, Int), rhs: Rhs)
    (using funcMap: Map[String, FuncSymbol], params: List[Param] = Nil,
     globalMainEnv: GlobalMainEnv, errors: mutable.ListBuffer[SemanticError]):
    Option[Type] = rhs match {
  case ArrayLiter(_, elems) =>
    // Elements that cannot be typed are ignored here (findExprType handles them).
    val elemTypes = elems.flatMap(e => findExprType(pos, e))
    elemTypes.distinct match {
      case t :: Nil => Some(ArrayType(t))
      case Nil => Some(ArrayType(AnyType)) // empty array => AnyType
      case multiple =>
        // Fold the distinct element types down to one that all of them cast to.
        val lcaType = multiple.reduceLeft { (t1, t2) =>
          if t1.typecast(Some(t2)) then t2
          else if t2.typecast(Some(t1)) then t1
          else {
            addSemanticError(pos, s"[${elemTypes.map(elem => elem.getType).mkString(", ")}]", "Multiple incompatible types in array",
              GeneralTypeError.apply)
            AnyType
          }
        }
        Some(ArrayType(lcaType))
    }
  case NewPair(_, fst, snd) => (findExprType(pos, fst), findExprType(pos, snd)) match {
    case (Some(f), Some(s)) => Some(PairType(f, s))
    case _ =>
      addSemanticError(pos, s"", "One of those types is invalid for newpair",
        PairAssignmentError.apply)
      None
  }
  // A pair element is both an Rhs and an Lhs; bind it by type instead of casting.
  case pairElem: PairElem => findLhsType(pos, pairElem)
  case Call(callPos, ident, args) =>
    // Strip the "#<n>" suffix so the error shows the name as written.
    val funcNonUniqueName = ident.id.replaceAll("#\\d+$", "")
    funcMap.get(ident.id) match {
      case Some(FuncSymbol(_, _, returnType, paramTypes, _)) =>
        if args.size != paramTypes.size then
          addSemanticError(callPos, s"", s"Function ${funcNonUniqueName} called with wrong # of arguments",
            FunctionCallError.apply)
        // Keep one slot per argument (None when untypable) so that every
        // argument stays aligned with its own parameter when zipped.
        val argTypes = args.map(findExprType(pos, _))
        argTypes.zip(paramTypes).foreach {
          case (Some(aType), pType) if !(aType typecast Some(pType._type)) =>
            addSemanticError(callPos, s"",
              s"Argument type ${aType.getType} does not match param type ${pType._type.getType}",
              FunctionCallError.apply)
          case _ => ()
        }
        Some(returnType)
      case None =>
        addSemanticError(callPos, s"", s"Function ${funcNonUniqueName} not found",
          FunctionCallError.apply)
        None
    }
  case expr: Expr => findExprType(pos, expr)
}
/**
 * Infers the type of an assignment target: an identifier, an array element
 * or a pair element.
 *
 * For a pair element the selected component of the inner l-value's pair type
 * is returned; when the inner l-value is not a fully specified pair the
 * result is AnyType.
 */
def findLhsType(pos: (Int, Int), lhs: Lhs)
    (using params: List[Param] = Nil, globalMainEnv: GlobalMainEnv,
     errors: mutable.ListBuffer[SemanticError]): Option[Type] = lhs match {
  case ident: Ident => findExprType(pos, ident)
  case ArrayElem(_, arrayIdent, indices) => arrayAccessType(pos, arrayIdent, indices)
  case PairElem(_, selector, subLhs) =>
    findLhsType(pos, subLhs) match {
      case Some(PairType(first, second)) =>
        val chosen = selector match {
          case Fst => first
          case Snd => second
        }
        // An erased component (PairPlaceholder) is itself a Type and is returned as-is.
        Option(chosen).collect { case t: Type => t }.orElse(Some(AnyType))
      // Unknown or non-pair inner type: nothing more precise can be said.
      case _ => Some(AnyType)
    }
}
// Handle arrayelement type extraction, checking index types.
/**
 * Computes the type obtained by indexing `ident` with `indices`, checking
 * that the indices are ints and that there are no more indices than array
 * dimensions.
 *
 * @return the element type after peeling one array layer per index, or None
 *         when the identifier itself cannot be typed
 */
def arrayAccessType(pos: (Int, Int), ident: Ident, indices: List[Expr])
    (using params: List[Param] = Nil, globalMainEnv: GlobalMainEnv,
     errors: mutable.ListBuffer[SemanticError]): Option[Type] = findExprType(pos, ident) match {
  case Some(arrType) =>
    // Must all be Int
    val validIndices = indices.forall(findExprType(pos, _) contains IntType)
    if !validIndices then addSemanticError(pos, s"${arrType.getType}", "Array indices must be of type int",
      GeneralTypeError.apply)
    // Peel off one array layer per index, remembering whether we ran out of layers.
    val (finalType, tooManyIndices) = indices.foldLeft((arrType, false)) {
      case ((ArrayType(inner), overflowed), _) => (inner, overflowed)
      case ((nonArr, _), _) => (nonArr, true)
    }
    // Report the problem once, naming the identifier (not its case-class toString).
    if tooManyIndices then
      addSemanticError(pos, s"${ident.id}", "Too many indices for array", PairAssignmentError.apply)
    Some(finalType)
  case None =>
    addSemanticError(pos, s"", "Invalid array access (identifier not found or not an array)",
      PairAssignmentError.apply)
    None
}
}

View File

@ -0,0 +1,466 @@
package wacc.frontend.syntax
import parsley.generic.{ParserBridge1, ParserBridge2, ParserBridge3, ParserBridge4/*, ParserBridge5*/}
object ast {
// ====================================================
// PROGRAM STRUCTURE
// ====================================================
// program ::= begin func* stmt end. Condenses <stmt> into List[Stmt]
/** Root of the AST: the imports, the function definitions and the main statement list. */
case class WProgram(
  pos: (Int, Int),
  imports: List[Import],
  funcs: List[Func],
  stats: List[Stmt]
)

/** An import statement: `import string`, where `path` is the string given. */
case class Import(pos: (Int, Int), path: String)

/**
 * A function definition.
 *   func       ::= type ident ( param-list? ) is stmt end
 *   param-list ::= param ( , param )*
 */
case class Func(
  pos: (Int, Int),
  t: Type,
  id: Ident,
  args: List[Param],
  body: List[Stmt]
)

/** A function parameter: param ::= type ident */
case class Param(pos: (Int, Int), t: Type, id: Ident)
// ====================================================
// STATEMENTS
// ====================================================
/** Base of every statement node. Sequencing (`stmt ; stmt`) is represented by List[Stmt]. */
sealed trait Stmt

/** skip */
case class Skip(pos: (Int, Int)) extends Stmt

/** Declaration: type ident = rvalue */
case class Declare(pos: (Int, Int), t: Type, ident: Ident, rhs: Rhs) extends Stmt

/** Assignment: lvalue = rvalue */
case class Assign(pos: (Int, Int), lhs: Lhs, rhs: Rhs) extends Stmt

/** read lvalue */
case class Read(pos: (Int, Int), lhs: Lhs) extends Stmt

/** free expr */
case class Free(pos: (Int, Int), expr: Expr) extends Stmt

/** return expr */
case class Return(pos: (Int, Int), expr: Expr) extends Stmt

/** exit expr */
case class Exit(pos: (Int, Int), expr: Expr) extends Stmt

/** Common shape of `print` and `println`, so both can be handled uniformly (used for code gen). */
sealed trait PrintBase extends Stmt {
  def pos: (Int, Int)
  def expr: Expr
}

/** print expr */
case class Print(pos: (Int, Int), expr: Expr) extends PrintBase

/** println expr */
case class Println(pos: (Int, Int), expr: Expr) extends PrintBase

/** if expr then stmt else stmt fi */
case class If(
  pos: (Int, Int),
  expr: Expr,
  thenCase: List[Stmt],
  elseCase: List[Stmt]
) extends Stmt

/** while expr do stmt done */
case class While(pos: (Int, Int), expr: Expr, body: List[Stmt]) extends Stmt

/** begin stmt end */
case class Block(pos: (Int, Int), stmts: List[Stmt]) extends Stmt
// stmt ; stmt is captured by List[Stmt]
// ====================================================
// LVALUE AND RVALUE
// ====================================================
// Both can be expressions
/** Anything that may appear on the left of an assignment. */
sealed trait Lhs

/** Anything that may appear on the right of an assignment. */
sealed trait Rhs

// ----------------------------------------------------
// Expressions
// ----------------------------------------------------
/** Every expression is usable as an r-value. */
sealed trait Expr extends Rhs

// Literals
case class IntLiter(pos: (Int, Int), value: BigInt) extends Expr
case class BoolLiter(pos: (Int, Int), value: Boolean) extends Expr
case class CharLiter(pos: (Int, Int), value: Char) extends Expr
case class StrLiter(pos: (Int, Int), value: String) extends Expr

/** The `null` literal (for pairs). */
case class PairLiter(pos: (Int, Int)) extends Expr

/** array-liter ::= [ ( expr (, expr)* )? ] -- an r-value only, not an expression. */
case class ArrayLiter(pos: (Int, Int), elems: List[Expr]) extends Rhs

/** newpair ( expr , expr ) */
case class NewPair(pos: (Int, Int), fst: Expr, snd: Expr) extends Rhs

/** call ident ( arg-list? ) */
case class Call(pos: (Int, Int), id: Ident, args: List[Expr]) extends Rhs
// ----------------------------------------------------
// Operators
// ----------------------------------------------------
// Descending Order of Precedence, with top paragraph being most tightly binding
// Unary operators
/** A prefix operator applied to a single operand. */
sealed trait UnOp extends Expr {
  def expr: Expr
}

/** Conditional expression with a condition and one expression per branch. */
case class IfExpr(
  pos: (Int, Int),
  cond: Expr,
  thenBranch: Expr,
  elseBranch: Expr
) extends Expr

// Prefix operators
case class Not(expr: Expr) extends UnOp
case class Neg(expr: Expr) extends UnOp
case class Len(expr: Expr) extends UnOp
case class Ord(expr: Expr) extends UnOp
case class Chr(expr: Expr) extends UnOp

/** An infix operator with a left and a right operand. */
sealed trait BinOp extends Expr {
  def lhs: Expr
  def rhs: Expr
}

/** Marker for the division-like operators (`/` and `%`). */
sealed trait DivOp extends BinOp

/** Marker for the comparison and equality operators. */
sealed trait CompOp extends BinOp

// Multiplicative (infix left)
case class Mul(lhs: Expr, rhs: Expr) extends BinOp
case class Div(lhs: Expr, rhs: Expr) extends DivOp
case class Mod(lhs: Expr, rhs: Expr) extends DivOp

// Additive (infix left)
case class Add(lhs: Expr, rhs: Expr) extends BinOp
case class Sub(lhs: Expr, rhs: Expr) extends BinOp

// Relational (infix non-associative)
case class Gt(lhs: Expr, rhs: Expr) extends CompOp
case class Geq(lhs: Expr, rhs: Expr) extends CompOp
case class Lt(lhs: Expr, rhs: Expr) extends CompOp
case class Leq(lhs: Expr, rhs: Expr) extends CompOp

// Equality (infix non-associative)
case class Eq(lhs: Expr, rhs: Expr) extends CompOp
case class Neq(lhs: Expr, rhs: Expr) extends CompOp

// Logical (infix right)
case class And(lhs: Expr, rhs: Expr) extends BinOp
case class Or(lhs: Expr, rhs: Expr) extends BinOp
// ----------------------------------------------------
// LHS: Left-hand sides for assignments
// ----------------------------------------------------
// An identifier: ident
/**
 * An identifier. Equality and hashing look only at the name, so two
 * occurrences of the same identifier at different positions compare equal.
 */
case class Ident(pos: (Int, Int), id: String) extends Lhs with Expr {
  override def equals(obj: Any): Boolean = obj match {
    case other: Ident => id == other.id
    case _ => false
  }
  override def hashCode(): Int = id.hashCode
}

/** An array element access: an identifier followed by a list of index expressions. */
case class ArrayElem(pos: (Int, Int), id: Ident, indices: List[Expr]) extends Lhs with Expr

/** Which component of a pair is being selected. */
sealed trait PairSelector
case object Fst extends PairSelector
case object Snd extends PairSelector

/**
 * A pair element, usable both as an l-value and as an expression:
 *   pair-elem ::= fst lvalue | snd lvalue
 */
case class PairElem(
  pos: (Int, Int),
  selector: PairSelector,
  expr: Lhs
) extends Lhs with Expr
// ====================================================
// TYPES
// ====================================================
// For pair types, the elements in the declarations must be either a base type or an array type
// as (Nested pairs are erased).
// A separate trait to mark types that can appear as elements of a pair.
/** A type that may appear as a component of a pair type. */
sealed trait PairElemType {
  /**
   * Compares array element types: true when x and y are equal, or when x can
   * be pair-cast to y.
   */
  private def arrayCompatible(x: PairElemType, y: PairElemType): Boolean =
    x == y || x.paircast(Some(y))

  /**
   * Attempt to typecast this type into the other (Option[PairElemType]).
   * For example, 'AnyType' can match anything, an Array of 'AnyType' can
   * match any array, etc. If the other is None, returns false.
   */
  infix def typecast(other: Option[PairElemType]): Boolean = other match {
    case None => false
    case Some(o) => (this, o) match {
      // Pair placeholders can match a concrete pair (and vice versa)
      case (PairType(_, _), PairPlaceholder) | (PairPlaceholder, PairType(_, _)) => true
      // For two actual PairTypes, defer to paircast
      case (PairType(_, _), PairType(_, _)) => this.paircast(other)
      // An array of AnyType is compatible with any array
      case (ArrayType(AnyType), ArrayType(_)) => true
      // Compare arrays elementwise (including deeper pair casting)
      case (ArrayType(x), ArrayType(y)) => arrayCompatible(x, y)
      // "char[]" can be cast to "string"
      case (ArrayType(CharType), StrType) => true
      // "AnyType" matches anything, or anything can be cast to "AnyType"
      case (AnyType, _) => true
      case (_, AnyType) => true
      // Otherwise, check for exact equality
      case _ => this == o
    }
  }

  /**
   * A specialized version of casting that strictly compares pairs. If it is not
   * a pair scenario or an exactly matching type, this returns false.
   */
  infix def paircast(other: Option[PairElemType]): Boolean = other match {
    case None => false
    case Some(o) => (this, o) match {
      case (PairType(f1, s1), PairType(f2, s2)) =>
        f1.paircast(Some(f2)) && s1.paircast(Some(s2))
      // Pair placeholders can match any concrete PairType
      case (PairType(_, _), PairPlaceholder) => true
      case (PairPlaceholder, PairType(_, _)) => true
      // Array[AnyType] can match any array
      case (ArrayType(AnyType), ArrayType(_)) => true
      // For two arrays, they must be exactly the same. (No 'AnyType' logic here.)
      case (ArrayType(x), ArrayType(y)) => x == y
      // Otherwise, they must be exactly equal
      case _ => this == o
    }
  }

  /** The type as written in source, used in error messages. */
  def getType: String
}

sealed trait Type extends PairElemType

/** Wildcard type used when nothing more precise is known. */
case object AnyType extends Type { override def getType: String = "any" }

// Base types
case object IntType extends Type { override def getType = "int" }
case object BoolType extends Type { override def getType = "bool" }
case object CharType extends Type { override def getType = "char" }
case object StrType extends Type { override def getType = "string" }

// Array type: any type followed by [].
case class ArrayType(elemType: Type) extends Type {
  override def getType = s"${elemType.getType}[]"
}

// Erased pair type, written as a bare `pair`.
// e.g: pair(int, pair) is valid but pair(pair(int, char), bool) is not
case object PairPlaceholder extends Type {
  override def getType = "pair"
}

// A fully specified pair type.
case class PairType(fst: PairElemType, snd: PairElemType) extends Type {
  // Custom equality: an AnyType component on THIS side acts as a wildcard.
  override def equals(obj: Any): Boolean = obj match {
    case PairType(f1, s1) =>
      (fst == AnyType || fst == f1) && (snd == AnyType || snd == s1)
    case _ => false
  }

  // A wildcard component can equal any component, so the hash must not depend
  // on either component; otherwise values that `equals` treats as equal could
  // hash differently and break hash-based collections.
  override def hashCode(): Int = "PairType".hashCode

  override def getType = s"pair(${fst.getType}, ${snd.getType})"
}
// ====================================================
// PARSER BRIDGES
// ====================================================
// We can have null extending base type or as a literal extending expr im not sure yet
// case object Null extends BaseType {override def getType() = "null"}
// Label reported by every statement bridge below.
private val statementLabels: List[String] = List("statement")

object WProgram extends ParserBridge4[(Int, Int), List[Import], List[Func], List[Stmt], WProgram]
object Import extends ParserBridge2[(Int, Int), String, Import]

// The parser delivers position, return type and name as one nested tuple
// (they are parsed together), so it is unpacked here.
object Func extends ParserBridge3[(((Int, Int), Type), Ident), List[Param], List[Stmt], Func] {
  def apply(tupleOfPosTypeIdent: (((Int, Int), Type), Ident), params: List[Param], body: List[Stmt]):
      Func = tupleOfPosTypeIdent match {
    case ((pos, returnType), name) => new Func(pos, returnType, name, params, body)
  }
  override def labels: List[String] = List("function declaration")
}

object Param extends ParserBridge3[(Int, Int), Type, Ident, Param]

// `skip` carries no payload; the second argument is whatever the keyword parser produced.
object Skip extends ParserBridge2[(Int, Int), Any, Skip] {
  override def apply(pos: (Int, Int), _x: Any): Skip = new Skip(pos)
  override def labels: List[String] = statementLabels
}
object Declare extends ParserBridge4[(Int, Int), Type, Ident, Rhs, Declare] {
  override def labels: List[String] = statementLabels
}
object Assign extends ParserBridge3[(Int, Int), Lhs, Rhs, Assign] {
  override def labels: List[String] = statementLabels
}
object Read extends ParserBridge2[(Int, Int), Lhs, Read] {
  override def labels: List[String] = statementLabels
}
object Free extends ParserBridge2[(Int, Int), Expr, Free] {
  override def labels: List[String] = statementLabels
}
object Return extends ParserBridge2[(Int, Int), Expr, Return] {
  override def labels: List[String] = statementLabels
}
object Exit extends ParserBridge2[(Int, Int), Expr, Exit] {
  override def labels: List[String] = statementLabels
}
object Print extends ParserBridge2[(Int, Int), Expr, Print] {
  override def labels: List[String] = statementLabels
}
object Println extends ParserBridge2[(Int, Int), Expr, Println] {
  override def labels: List[String] = statementLabels
}
object If extends ParserBridge4[(Int, Int), Expr, List[Stmt], List[Stmt], If] {
  override def labels: List[String] = statementLabels
}
object While extends ParserBridge3[(Int, Int), Expr, List[Stmt], While] {
  override def labels: List[String] = statementLabels
}
object Block extends ParserBridge2[(Int, Int), List[Stmt], Block] {
  override def labels: List[String] = statementLabels
}
// Array elements are parsed as (position, identifier), a first index and any
// further indices; the indices are merged into one list here.
object ArrayElem extends ParserBridge3[((Int, Int), Ident), Expr, List[Expr], ArrayElem] {
  def apply(tupleOfPosIdent: ((Int, Int), Ident), head: Expr, rest: List[Expr]): ArrayElem =
    tupleOfPosIdent match {
      case (pos, ident) => new ArrayElem(pos, ident, head :: rest)
    }
}
object Ident extends ParserBridge2[(Int, Int), String, Ident] {
  override def labels: List[String] = List("identifier")
}

// Literal bridges, each with the label reported for it.
object IntLiter extends ParserBridge2[(Int, Int), BigInt, IntLiter] {
  override def labels: List[String] = List("integer literal")
}
object BoolLiter extends ParserBridge2[(Int, Int), Boolean, BoolLiter] {
  override def labels: List[String] = List("boolean literal")
}
object CharLiter extends ParserBridge2[(Int, Int), Char, CharLiter] {
  override def labels: List[String] = List("character literal")
}
object StrLiter extends ParserBridge2[(Int, Int), String, StrLiter] {
  override def labels: List[String] = List("string literal")
}
// The pair literal carries no payload; the second argument is ignored.
object PairLiter extends ParserBridge2[(Int, Int), Any, PairLiter] {
  override def apply(pos: (Int, Int), _x: Any): PairLiter = new PairLiter(pos)
  override def labels: List[String] = List("pair literal")
}
object ArrayLiter extends ParserBridge2[(Int, Int), List[Expr], ArrayLiter] {
  override def labels: List[String] = List("array literal")
}
object NewPair extends ParserBridge3[(Int, Int), Expr, Expr, NewPair]
object Call extends ParserBridge3[(Int, Int), Ident, List[Expr], Call] {
  override def labels: List[String] = List("function call")
}
// Labels shared by the operator bridges below.
private val expressionLabels: List[String] = List("expression")
private val binaryOperatorLabels: List[String] = List("binary operator")

// Prefix operators and the conditional expression are labelled "expression".
object Not extends ParserBridge1[Expr, Not] {
  override def labels: List[String] = expressionLabels
}
object Neg extends ParserBridge1[Expr, Neg] {
  override def labels: List[String] = expressionLabels
}
object Len extends ParserBridge1[Expr, Len] {
  override def labels: List[String] = expressionLabels
}
object Ord extends ParserBridge1[Expr, Ord] {
  override def labels: List[String] = expressionLabels
}
object Chr extends ParserBridge1[Expr, Chr] {
  override def labels: List[String] = expressionLabels
}
object IfExpr extends ParserBridge4[(Int, Int), Expr, Expr, Expr, IfExpr] {
  override def labels: List[String] = expressionLabels
}

// Infix operators are labelled "binary operator".
object Mul extends ParserBridge2[Expr, Expr, Mul] {
  override def labels: List[String] = binaryOperatorLabels
}
object Div extends ParserBridge2[Expr, Expr, Div] {
  override def labels: List[String] = binaryOperatorLabels
}
object Mod extends ParserBridge2[Expr, Expr, Mod] {
  override def labels: List[String] = binaryOperatorLabels
}
object Add extends ParserBridge2[Expr, Expr, Add] {
  override def labels: List[String] = binaryOperatorLabels
}
object Sub extends ParserBridge2[Expr, Expr, Sub] {
  override def labels: List[String] = binaryOperatorLabels
}
object Gt extends ParserBridge2[Expr, Expr, Gt] {
  override def labels: List[String] = binaryOperatorLabels
}
object Geq extends ParserBridge2[Expr, Expr, Geq] {
  override def labels: List[String] = binaryOperatorLabels
}
object Lt extends ParserBridge2[Expr, Expr, Lt] {
  override def labels: List[String] = binaryOperatorLabels
}
object Leq extends ParserBridge2[Expr, Expr, Leq] {
  override def labels: List[String] = binaryOperatorLabels
}
object Eq extends ParserBridge2[Expr, Expr, Eq] {
  override def labels: List[String] = binaryOperatorLabels
}
object Neq extends ParserBridge2[Expr, Expr, Neq] {
  override def labels: List[String] = binaryOperatorLabels
}
object And extends ParserBridge2[Expr, Expr, And] {
  override def labels: List[String] = binaryOperatorLabels
}
object Or extends ParserBridge2[Expr, Expr, Or] {
  override def labels: List[String] = binaryOperatorLabels
}
object PairType extends ParserBridge2[PairElemType, PairElemType, PairType]

// The selector is parsed together with the position, so the pair is unpacked here.
object PairElem extends ParserBridge2[((Int, Int), PairSelector), Lhs, PairElem] {
  def apply(tupleOfPosSelector: ((Int, Int), PairSelector), lhs: Lhs): PairElem =
    tupleOfPosSelector match {
      case (pos, selector) => new PairElem(pos, selector, lhs)
    }
}

object ArrayType extends ParserBridge1[Type, ArrayType]
}

View File

@ -0,0 +1,40 @@
package wacc.frontend.syntax
import ast._
import scala.concurrent._
import scala.concurrent.ExecutionContext.Implicits.global
/**
 * Resolves `import` statements by parsing each imported file and merging the
 * functions it (transitively) provides into the importing program.
 *
 * NOTE(review): there is no cycle detection here; two files importing each
 * other would presumably recurse without end -- confirm whether an earlier
 * phase rules this out.
 */
class ImportHandler {
  // Below this many imports the overhead of parallelism is not worth it.
  private val minParallelImports = 4

  /**
   * Returns `program` with the functions of all its imports appended.
   *
   * @param program      program whose imports should be resolved
   * @param optimiseFlag when true, levels with at least four imports are
   *                     parsed in parallel; the flag is forwarded so the same
   *                     decision is taken at every nesting level
   */
  def addImports(program: WProgram, optimiseFlag: Boolean = false): WProgram = {
    val paths = program.imports.map(_.path)
    val imported: List[WProgram] =
      if (!optimiseFlag || paths.length < minParallelImports) {
        paths.map(path => addProgram(path, optimiseFlag))
      } else {
        // One future per import; they start running as soon as they are created.
        val futures = paths.map(path => Future(addProgram(path, optimiseFlag)))
        // Wait for all Futures to complete
        Await.result(Future.sequence(futures), scala.concurrent.duration.Duration.Inf)
      }
    program.copy(funcs = program.funcs ++ imported.flatMap(_.funcs))
  }

  /**
   * Parses the file at `path` and resolves its own imports.
   * On a syntax error the message is printed and the process exits with
   * SYNTAX_ERROR_CODE.
   *
   * @param path         file to parse
   * @param optimiseFlag forwarded to [[addImports]] for the nested imports
   */
  def addProgram(path: String, optimiseFlag: Boolean = false): WProgram = {
    parser.parseFile(path) match {
      case parsley.Failure(msg) => {
        println(s"\nSyntax Error $msg in file $path")
        sys.exit(SYNTAX_ERROR_CODE)
      }
      case parsley.Success(progFromSyntax: WProgram) => addImports(progFromSyntax, optimiseFlag)
    }
  }
}

View File

@ -0,0 +1,69 @@
package wacc.frontend.syntax
import parsley.Parsley
import parsley.token.{Lexer, Basic}
import parsley.token.descriptions._
import parsley.token.errors._
import parsley.token.Unicode
/** True for the ASCII letters 'A'-'Z' and 'a'-'z' only. */
def isEnglishLetter(c: Char): Boolean =
  (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')

/**
 * True for ASCII letters and the ASCII digits '0'-'9'.
 * `Char.isDigit` is deliberately avoided: it also accepts non-ASCII Unicode
 * digits, which must not be allowed inside identifiers.
 */
def isEnglishLetterOrDigit(c: Char): Boolean =
  isEnglishLetter(c) || (c >= '0' && c <= '9')
// Error configuration for the lexer: gives two symbols a descriptive label
// to be used in error messages.
val errorConfig = new ErrorConfig {
  override def labelSymbol: Map[String, LabelWithExplainConfig] = {
    val arrayIndex = "[" -> Label("array index")
    val assignment = "=" -> Label("assignment")
    Map(arrayIndex, assignment)
  }
}
/** Lexical description of the language and the token parsers built from it. */
object lexer {
  // Characters that may only appear escaped inside character/string literals.
  val excludedChars = Set('\"', '\'', '\\')

  // Reserved words; none of these can be used as an identifier.
  val keys = Set(
    "begin", "end", "is", "skip", "read", "free", "return", "exit",
    "print", "println", "if", "then", "else", "fi", "while", "do", "done",
    "fst", "snd", "newpair", "call", "int", "bool", "char", "string",
    "pair", "true", "false", "null", "len", "ord", "chr"
  )

  // Operator symbols ("-" serves as both the unary and the binary minus).
  val ops = Set(
    "=", "!", "-", "*", "/", "%", "+",
    ">", ">=", "<", "<=", "==", "!=", "&&", "||"
  )

  private val desc = {
    // Identifiers: a letter or underscore, then letters, digits or underscores.
    val names = NameDesc.plain.copy(
      identifierStart = Basic(c => isEnglishLetter(c) || c == '_'),
      identifierLetter = Basic(c => isEnglishLetterOrDigit(c) || c == '_')
    )
    val symbols = SymbolDesc.plain.copy(
      hardKeywords = keys,
      hardOperators = ops,
    )
    val escapes = EscapeDesc.plain.copy(
      literals = Set('\\', '\'', '\"'),
      mapping = Map("0" -> 0x00, "b" -> 0x08, "t" -> 0x09, "n" -> 0x0a, "f" -> 0x0c,
        "r" -> 0x0d)
    )
    // A literal character is anything from space upwards that is not excluded.
    val text = TextDesc.plain.copy(
      escapeSequences = escapes,
      graphicCharacter = Unicode(cp => !(cp < ' '.toInt || excludedChars.contains(cp.toChar)))
    )
    // Line comments start with '#'.
    val spaces = SpaceDesc.plain.copy(
      lineCommentStart = "#"
    )
    LexicalDesc.plain.copy(
      nameDesc = names,
      symbolDesc = symbols,
      textDesc = text,
      spaceDesc = spaces,
    )
  }

  private val lexer = new Lexer(desc, errorConfig)

  // Define the primitive atoms of the grammar (done in terms of Scala types)
  val integer = lexer.lexeme.integer.decimal32[BigInt]
  val character = lexer.lexeme.character.ascii
  val string = lexer.lexeme.string.ascii
  val implicits = lexer.lexeme.symbol.implicits
  val identifier = lexer.lexeme.names.identifier

  /** Wraps `p` with the lexer's `fully` combinator. */
  def fully[A](p: Parsley[A]): Parsley[A] = lexer.fully(p)
}

View File

@ -0,0 +1,259 @@
package wacc.frontend.syntax
import ast._
import lexer._
import lexer.implicits.implicitSymbol
import parsley.{Parsley, Result}
import parsley.combinator.{sepBy, sepBy1, ifS}
import parsley.errors.combinator._
import parsley.expr.{precedence, Ops, Prefix, InfixL, InfixN, InfixR}
import parsley.position.pos
import Parsley.{many, atomic, pure, empty}
import scala.util.{Success, Failure}
import java.io.File
object parser {
/**
 * Parses the file at path `input`. If the file itself cannot be processed
 * the result is a parse failure with the message "File not recognised".
 */
def parseFile(input: String): Result[String, WProgram] = {
  val attempt = parser.parseFile(new File(input))
  attempt.getOrElse(parsley.Failure("File not recognised"))
}

/** Parses a whole program given directly as a string. */
def parse(input: String): Result[String, WProgram] = parser.parse(input)

// Top-level parser: the program grammar wrapped by the lexer's `fully`.
private val parser = fully(prog)
// --------------------------
// PROGRAM
// --------------------------
// prog ::= import* "begin" func* stmts "end"
// Any imports come first, before the "begin" keyword.
private lazy val prog: Parsley[WProgram] =
WProgram(pos, many(_import), "begin" ~> many(func), stmts <~ "end")
// --------------------------
// IMPORTS
// --------------------------
// import ::= "import" string
// The string literal is kept as the import's path.
private lazy val _import: Parsley[Import] = Import(pos, "import" ~> string)
// --------------------------
// FUNCTIONS
// --------------------------
// func ::= <type> <ident> "(" paramList? ")" "is" stmts "end"
// The header up to and including "(" is atomic, so a failure inside it
// backtracks and lets the caller try something else.
// The body is filtered through `noReturns`; bodies that fail the filter are
// rejected with the `returnExplain` explanation.
private lazy val func: Parsley[Func] =
Func(atomic(pos <~> _type <~> ident <~ "("), paramList <~ ")", "is" ~> stmts.filter(noReturns).explain(returnExplain) <~ "end")
// paramList ::= ( param ("," param)* )?   -- sepBy also accepts an empty list
private lazy val paramList: Parsley[List[Param]] = sepBy(param, ",")
// param ::= <type> <ident>
private lazy val param: Parsley[Param] = Param(pos, _type, ident)
// Predicate used to validate function bodies: despite its name, it returns
// TRUE when every path through `stmts` ends in a return/exit, and false
// otherwise. Only the last statement of the list matters:
//   - `return` / `exit` terminate the path;
//   - an `if` terminates only when both branches do;
//   - a nested block terminates when its own body does.
// A trailing `while` is NOT accepted: its body may run zero times, so
// control can fall through the end of the function.
private def noReturns(stmts: List[Stmt]): Boolean = stmts.lastOption match {
  case Some(Return(_, _)) | Some(Exit(_, _)) => true
  case Some(If(_, _, thenStmts, elseStmts)) =>
    noReturns(thenStmts) && noReturns(elseStmts)
  case Some(Block(_, stats)) => noReturns(stats)
  case _ => false
}
// --------------------------
// IDENTIFIERS
// --------------------------
// Identifier, with the position where it starts.
private lazy val ident: Parsley[Ident] = Ident(pos, identifier)
// Literal parsers; each records the position where the literal starts.
private lazy val intLiter: Parsley[IntLiter] = IntLiter(pos, integer)
private lazy val boolLiter: Parsley[BoolLiter] =
BoolLiter(pos, "true" ~> pure(true) | "false" ~> pure(false))
private lazy val charLiter: Parsley[CharLiter] = CharLiter(pos, character)
private lazy val strLiter: Parsley[StrLiter] = StrLiter(pos, string)
// The pair literal is the keyword "null".
private lazy val pairLiter: Parsley[PairLiter] = PairLiter(pos, "null")
// Conditional expression: "if" expr "then" expr "else" expr "fi"
// NOTE(review): `_semiCheck` is defined outside this view; it is accepted
// in place of "else" -- presumably to produce a targeted error, confirm.
private lazy val ifExpr: Parsley[IfExpr] =
IfExpr(pos, "if" ~> expr, "then" ~> expr, ("else"| _semiCheck) ~> expr <~ "fi")
// --------------------------
// TYPES
// --------------------------
// simpleType ::= baseType | pairType | "(" _type ")"
// created to break the left recursion in type
// (an array type is a simpleType followed by "[]" suffixes, see _type)
private lazy val simpleType: Parsley[Type] =
baseType
| pairType
| ("(" ~> _type <~ ")")
// baseType ::= "int" | "bool" | "char" | "string"
private lazy val baseType: Parsley[Type] =
("int" as IntType)
| ("bool" as BoolType)
| ("char" as CharType)
| ("string" as StrType)
// pairType ::= "pair" "(" pairElemType "," pairElemType ")"
private lazy val pairType: Parsley[Type] =
PairType("pair" ~> ("(" ~> pairElemType), "," ~> pairElemType <~ ")")
// arraySuffix parses "[]" and returns a function wrapping a type in ArrayType
private lazy val arraySuffix: Parsley[Type => Type] =
ArrayType from ("[" ~> "]")
// _type ::= simpleType (arraySuffix)*
// Each "[]" suffix wraps the type parsed so far in one more ArrayType.
// Written with `<~>` + `map` rather than `flatMap`: the suffix parser does not
// depend on the parsed type, so no monadic bind is needed.
private lazy val _type: Parsley[Type] =
  (simpleType <~> many(arraySuffix)).map { case (base, suffixes) =>
    suffixes.foldLeft(base)((acc, wrap) => wrap(acc))
  }
// pairElemType: a base (or array) type for pair elements, or the literal "pair" for erasure
private lazy val pairElemType: Parsley[PairElemType] =
  ((baseType <~> many(arraySuffix)).map { case (base, suffixes) =>
    suffixes.foldLeft(base)((acc, wrap) => wrap(acc))
  })
  | ("pair" as PairPlaceholder)
// --------------------------
// EXPRESSIONS
// --------------------------
// Expression parser with precedence handling.
// Levels are listed from tightest to loosest binding. Operators that share a
// precedence MUST share one `Ops` level; giving each its own level would make
// e.g. `1 - 2 + 3` parse as `1 - (2 + 3)`.
//   1. prefix:          -  !  len  ord  chr
//   2. infix left:      *  %  /
//   3. infix left:      +  -
//   4. infix non-assoc: >  >=  <  <=
//   5. infix non-assoc: ==  !=
//   6. infix right:     &&
//   7. infix right:     ||
private lazy val expr: Parsley[Expr] =
  precedence(
    (intLiter
      | boolLiter
      | charLiter
      | strLiter
      | pairLiter
      | arrayElem
      | ident
      | ifExpr
      | ("(" ~> expr <~ ")")).label("expression").explain(expressionExplain))(
    Ops(Prefix)(
      // If the token is an integer, return the empty combinator to escape this case so that
      // it will be handled elsewhere (as an integer literal).
      // Otherwise, wrap the expression (which is not a pure integer) in the Neg object.
      // The outer `atomic` undoes the look-ahead so the other operators can still be tried.
      atomic(ifS(atomic(integer.hide) ~> pure(true)| pure(false),
        empty, Neg from "-")),
      Not from "!",
      Len from "len",
      Ord from "ord",
      Chr from "chr"),
    Ops(InfixL)(Mul from "*", Mod from "%", Div from "/"),
    Ops(InfixL)(Add from "+", Sub from "-"),
    Ops(InfixN)(Gt from ">", Geq from ">=", Lt from "<", Leq from "<="),
    Ops(InfixN)(Eq from "==", Neq from "!="),
    Ops(InfixR)(And from "&&"),
    Ops(InfixR)(Or from "||")
  )
// --------------------------
// STATEMENTS
// --------------------------
// stmt: a single statement. The alternatives are tried in the order listed and
// the first one that succeeds is used.
private lazy val stmt: Parsley[Stmt] =
skipStmt
| declareStmt
| assignStmt
| readStmt
| freeStmt
| returnStmt
| exitStmt
| printStmt
| printlnStmt
| ifStmt
| whileStmt
| blockStmt
// stmts ::= stmt (";" stmt)*
// sepBy1 requires at least one statement; ";" is a separator, not a terminator.
private lazy val stmts: Parsley[List[Stmt]] = sepBy1(stmt, ";")
// Every statement bridge below receives `pos` as its first argument
// (presumably the source position recorded on the AST node — defined outside this section).
// "skip" statement.
private lazy val skipStmt: Parsley[Skip] = Skip(pos, "skip")
// Declaration: <type> <ident> "=" rVal
private lazy val declareStmt: Parsley[Declare] = Declare(pos, _type, ident <~ "=", rVal)
// Assignment: lVal "=" rVal
private lazy val assignStmt: Parsley[Assign] = Assign(pos, lVal <~ "=", rVal)
// Read statement: "read" lVal
private lazy val readStmt: Parsley[Read] = Read(pos, "read" ~> lVal)
// Free statement: "free" expr
private lazy val freeStmt: Parsley[Free] = Free(pos, "free" ~> expr)
// Return statement: "return" expr
private lazy val returnStmt: Parsley[Return] = Return(pos, "return" ~> expr)
// Exit statement: "exit" expr
private lazy val exitStmt: Parsley[Exit] = Exit(pos, "exit" ~> expr)
// Print statement: "print" expr
private lazy val printStmt: Parsley[Print] = Print(pos, "print" ~> expr)
// Println statement: "println" expr
private lazy val printlnStmt: Parsley[Println] = Println(pos, "println" ~> expr)
// While statement: "while" expr "do" stmts "done"
private lazy val whileStmt: Parsley[While] = While(pos, "while" ~> expr, "do" ~> stmts <~ "done")
// If statement: "if" expr "then" stmts "else" stmts "fi"
// Where "else" is expected, _semiCheck is offered as an alternative: it matches a
// stray ';' so that the mistake is reported with its dedicated explanation
// ("semi-colons cannot be written between `if` and `else`").
private lazy val ifStmt: Parsley[If] =
If(pos, "if" ~> expr, "then" ~> stmts, ("else"| _semiCheck) ~> stmts <~ "fi")
// Block statement: "begin" stmts "end"
// Wraps a non-empty statement sequence (see stmts) in a Block node.
private lazy val blockStmt: Parsley[Block] = Block(pos, "begin" ~> stmts <~ "end")
// --------------------------
// L-VALUES and R-VALUES
// --------------------------
// lVal ::= arrayElem | pairElem | lident
// arrayElem is tried before ident because both start with an identifier.
private lazy val lVal: Parsley[Lhs] =
arrayElem
| pairElem
| ident
// arrayElem ::= lident ("[" expr "]")+
// The identifier and the first "[" are parsed atomically, so an identifier that is
// not followed by "[" consumes no input here and later alternatives can still match.
// The first index is parsed separately from the remaining ones, which enforces
// at least one index.
private lazy val arrayElem: Parsley[ArrayElem] = ArrayElem(atomic(pos <~> ident <~ "["), expr <~ "]", many("[" ~> expr <~ "]"))
// pairElem ::= ("fst" | "snd") lVal
// The selector keyword is mapped to Fst or Snd and paired with `pos`.
private lazy val pairElem: Parsley[PairElem] =
PairElem(atomic(pos <~> (("fst" as Fst) | ("snd" as Snd))), lVal)
// rVal ::= expr | arrayLiter | newPair | pairElem | call
// Alternatives are tried in the order listed.
private lazy val rVal: Parsley[Rhs] =
expr
| arrayLiter
| newPair
| pairElem
| call
// arrayLiter ::= "[" exprs "]"
// The element list may be empty, since exprs accepts zero expressions.
private lazy val arrayLiter: Parsley[ArrayLiter] =
ArrayLiter(pos, "[" ~> exprs <~ "]")
// newPair ::= "newpair" "(" expr "," expr ")"
private lazy val newPair: Parsley[NewPair] =
NewPair(pos, "newpair" ~> ("(" ~> expr), "," ~> expr <~ ")")
// call ::= "call" lident "(" exprs ")"
// The argument list may be empty, since exprs accepts zero expressions.
private lazy val call: Parsley[Call] = Call(pos, "call" ~> ident, "(" ~> exprs <~ ")")
// exprs ::= (expr ("," expr)*)?
// sepBy (unlike sepBy1) also succeeds with an empty list.
private lazy val exprs: Parsley[List[Expr]] = sepBy(expr, ",")
}

View File

@ -0,0 +1,17 @@
package wacc.frontend.syntax
import parsley.character.char
import parsley.errors.patterns.VerifiedErrors
// Numeric code associated with syntax errors (where it is used is outside this file).
val SYNTAX_ERROR_CODE = 100
// Verified-error parser for a stray ';': when a ';' is found at this point, parsing
// fails with the explanation below instead of a generic "unexpected" message.
val _semiCheck = char(';').verifiedExplain("semi-colons cannot be written between `if` and `else`")
// Defining explanations for constructs that yield the same error explanation message
// Explanation attached to the expression parser.
val expressionExplain = "expressions may start with integer, string, character or boolean\n" +
" literals; identifiers; unary operators; null; or parentheses.\n" +
" In addition, expressions may contain array indexing operations; and\n" +
" comparison, logical, and arithmetic operators"
// Explanation text about the mandatory begin ... end wrapper (usage is outside this file).
val bodyEndExplain = "all programs must be enclosed in begin ... end"
// Explanation text about missing returns (usage is outside this file).
val returnExplain = "Missing a return on all exit paths"

114
src/test/wacc/codeGen.scala Normal file
View File

@ -0,0 +1,114 @@
import org.scalatest.flatspec.AnyFlatSpec
class codeGenTest extends CodeGeneratorTestSuite {
  /* Wrapper for uniform testing: plain code generation, with neither a peephole
   * optimiser nor an import handler. */
  def test(input: os.Path): (Boolean, String) = super.test(input, None, None)

  // IGNORE SET: Contains files/directories to ignore in testing
  val ignoreSet: Set[os.Path] = Set()

  // ----------------------------------------------------
  // Programs with syntax errors: generation must fail
  // ----------------------------------------------------
  behavior of "TESTING CODE GEN IN syntaxErr"
  Seq(
    invalidSyntaxArray,
    invalidSyntaxBasic,
    invalidSyntaxExpressions,
    invalidSyntaxFunction,
    invalidSyntaxIf,
    invalidSyntaxLiteral,
    invalidSyntaxPairs,
    invalidSyntaxPrint,
    invalidSyntaxSequence,
    invalidSyntaxVariables,
    invalidSyntaxWhile
  ).foreach(dir => genTest("Code Generator", ignoreSet, test, "fail", os.list(dir)))

  // ----------------------------------------------------
  // Programs with semantic errors: generation must fail
  // ----------------------------------------------------
  behavior of "TESTING CODE GEN IN semanticErr"
  Seq(
    invalidSemanticArray,
    invalidSemanticExit,
    invalidSemanticExpressions,
    invalidSemanticFunction,
    invalidSemanticIf,
    invalidSemanticIO,
    invalidSemanticMultiple,
    invalidSemanticPairs,
    invalidSemanticPrint,
    invalidSemanticRead,
    invalidSemanticScope,
    invalidSemanticVariables,
    invalidSemanticWhile
  ).foreach(dir => genTest("Code Generator", ignoreSet, test, "fail", os.list(dir)))

  // ----------------------------------------------------
  // Valid programs: one behaviour group per directory, generation must succeed
  // ----------------------------------------------------
  Seq(
    "valid/basic/exit" -> validBasicExit,
    "valid/basic/skip" -> validBasicSkip,
    "valid/expressions" -> validExpression,
    "valid/function/nested_functions" -> validFunctionNestFuns,
    "valid/function/simple_functions" -> validFunctionSimpFuns,
    "valid/function/overload_functions" -> validFunctionOverFuns,
    "valid/if" -> validIf,
    "valid/IO/print" -> validIOPrint,
    "valid/IO/read" -> validIORead,
    "valid/IO/special" -> validIOSpecial,
    "valid/pairs" -> validPair,
    "valid/runtimeErr/arrayOutOfBounds" -> validRTEArrOOB,
    "valid/runtimeErr/badChar" -> validRTEBadChar,
    "valid/runtimeErr/divideByZero" -> validRTEDivByZero,
    "valid/runtimeErr/integerOverflow" -> validRTEIntOverflow,
    "valid/runtimeErr/nullDereference" -> validRTENullDereference,
    "valid/scope" -> validScope,
    "valid/sequence" -> validSequence,
    "valid/variables" -> validVariables,
    "valid/while" -> validWhile
  ).foreach { case (label, dir) =>
    behavior of s"TESTING CODE GEN IN $label"
    genTest("Code Generator", ignoreSet, test, "succeed", os.list(dir))
  }
}

View File

@ -0,0 +1,194 @@
import org.scalatest.flatspec.AnyFlatSpec
import parsley._
import scala.util.{Try, Success, Failure}
import wacc.backend.CodeGenerator
import wacc.extension.Peephole
import wacc.frontend.semantic.analyse
import wacc.frontend.syntax.ast.WProgram
import wacc.frontend.syntax.ImportHandler
import java.time.Instant
// This is the trait for all code generating tests (CodeGen, Imports, Peephole)
trait CodeGeneratorTestSuite extends ParserTestSuite {
/**
  * Runs a wacc file through parsing, optional import resolution, semantic analysis,
  * code generation and an optional peephole pass, then checks the emitted assembly
  * with checkAsm.
  *
  * The assembly is written to src/test/wacc/waccOutputs under the working directory;
  * that directory is removed again before returning, whether or not generation and
  * checking succeed.
  *
  * @param input Path to wacc file
  * @param peepholeOpt when defined, applied to the generated assembly before it is emitted
  * @param importHandlerOpt when defined, used to add imports to the parsed program before analysis
  * @param printParallelTests when true and an import handler is given, times non-parallel and
  *                           parallel import processing over 100 runs each and prints the results;
  *                           the same flag is forwarded as `optimiseFlag` for the import
  *                           resolution whose result is actually analysed
  * @return either (true, "") on a successful gen or (false, msg) on an unsuccessful gen
  */
def test(input: os.Path, peepholeOpt: Option[Peephole], importHandlerOpt: Option[ImportHandler], printParallelTests: Boolean = false): (Boolean, String) = {
  // Average wall-clock duration, in nanoseconds, of `runs` evaluations of `action`
  def averageNanos(runs: Int)(action: => Any): Long = {
    val total = (1 to runs).foldLeft(0L) { (acc, _) =>
      val start = Instant.now()
      val _ = action
      val end = Instant.now()
      acc + java.time.Duration.between(start, end).toNanos
    }
    total / runs
  }

  parseFile(input) match {
    case parsley.Success(parsedProg: WProgram) =>
      // Adds the program to be analysed based on whether the import handler exists or not
      val progToAnalyse = importHandlerOpt match {
        case Some(importHandler) =>
          if (printParallelTests) {
            val testNum = 100
            println(s"Starting Non-Parallel import processing for ${input.last} over $testNum tests")
            val sequentialAverage = averageNanos(testNum)(importHandler.addImports(parsedProg))
            println(s"Non-Parallel processing took: $sequentialAverage ns on average")
            println(s"Starting Parallel import processing for ${input.last}")
            val parallelAverage =
              averageNanos(testNum)(importHandler.addImports(parsedProg, optimiseFlag = true))
            println(s"Parallel processing took: $parallelAverage ns")
          }
          importHandler.addImports(parsedProg, optimiseFlag = printParallelTests)
        case None => parsedProg
      }
      // Analyse the program and perform further checks after
      analyse(progToAnalyse) match {
        case Left(res) =>
          val outputFile = s"${input.last.replace(".wacc", ".s")}"
          val outputPath = os.pwd / "src" / "test" / "wacc" / "waccOutputs" / outputFile
          val outputDir = outputPath / os.up
          // Make the outputPath's folder if it doesn't already exist
          os.makeDir.all(outputDir)
          // Everything that can throw after the folder exists sits inside the try, so the
          // folder is removed even when generation or writing fails, not only checkAsm.
          try {
            val (analysedProg, fEnv, mEnv) = res
            val codeGenerator = new CodeGenerator()
            // Generate the asm without any possible peephole optimization
            val unOptimizedAsm = codeGenerator.generate(analysedProg, fEnv, mEnv)
            val asmCode = peepholeOpt match {
              case Some(peephole) => peephole.optimize(unOptimizedAsm).map(_.emit).mkString("\n")
              case None => unOptimizedAsm.map(_.emit).mkString("\n")
            }
            // Writes to the folder location given that it doesn't exist or contains a different asm
            if (!os.exists(outputPath) || os.read(outputPath) != asmCode) {
              os.write.over(outputPath, asmCode)
            }
            checkAsm(input, outputPath)
          } finally {
            os.remove.all(outputDir)
          }
        case Right(errs) =>
          val fileName = input.last
          (false, errs.map(err => err.getMessage(fileName, input.toString)).mkString("\n"))
      }
    case parsley.Failure(msg) => (false, msg)
  }
}
/**
  * Assembles the generated assembly with the aarch64 cross compiler, runs it under
  * qemu-aarch64 and compares exit code and standard output with the expectations
  * written in the wacc file's comments.
  *
  * Expectations read from the wacc file:
  *  - "# Exit:"   the following line holds the expected exit code (0 when absent or unparsable)
  *  - "# Output:" the following comment lines, up to the first empty one, are the expected output
  *  - "# Input:"  the rest of that line is fed to the program on standard input
  *
  * @param waccPath path to the wacc source holding the expectations
  * @param asmPath path to the assembly file; its directory is used as the working directory
  * @return (true, "") when exit code and normalised output match, otherwise (false, msg)
  */
def checkAsm(waccPath: os.Path, asmPath: os.Path): (Boolean, String) = {
  // Wraps a value in single quotes for sh, escaping any embedded single quote, so that
  // the shell passes it through verbatim (no expansion of $, backticks or double quotes)
  def shellQuote(raw: String): String = "'" + raw.replace("'", "'\\''") + "'"

  // Read the WACC file lines
  val lines = os.read.lines(waccPath)

  // Extract the expected exit code from a "# Exit:" block (default to 0 if not provided)
  val expectedExit: Int = {
    val headerIndex = lines.indexWhere(_.trim == "# Exit:")
    if (headerIndex == -1) 0
    else lines.lift(headerIndex + 1).flatMap(_.stripPrefix("#").trim.toIntOption).getOrElse(0)
  }

  // Extract expected output lines from the "# Output:" block
  // We take all lines after "# Output:" until we reach a blank line
  val expectedOutput: List[String] = {
    val startIndex = lines.indexWhere(_.trim == "# Output:")
    if (startIndex != -1)
      lines.drop(startIndex + 1)
        .takeWhile(line => line.stripPrefix("#").trim.nonEmpty)
        .map(_.stripPrefix("#").trim)
        .toList
    else Nil
  }

  // Extract input (if any) from a line starting with "# Input:".
  // The marker is removed from the trimmed line, i.e. the same text it was detected on,
  // so leading whitespace before '#' cannot leave the marker inside the input.
  val maybeInput: Option[String] =
    lines.map(_.trim).find(_.startsWith("# Input:")).map(_.stripPrefix("# Input:").trim)

  // Derive the assembly file name and the executable name
  val asmFileName = asmPath.last
  val executableName = asmFileName.stripSuffix(".s")
  // Get the directory containing the assembly file (this will be the working directory)
  val asmDir = asmPath / os.up

  // File names are quoted so that spaces or shell metacharacters in them are harmless
  val quotedAsm = shellQuote(asmFileName)
  val quotedExecutable = shellQuote(executableName)
  val quotedExecutablePath = shellQuote("./" + executableName)

  // Build the run command using qemu-aarch64
  val qemuString = s"qemu-aarch64 -L /usr/aarch64-linux-gnu/ $quotedExecutablePath"
  val runCommand = maybeInput match {
    case Some(input) =>
      // printf '%s\n' writes the input verbatim followed by a newline; the input itself
      // is single-quoted, matching the escaping performed by shellQuote
      s"printf '%s\\n' ${shellQuote(input)} | $qemuString"
    case None =>
      s"echo '' | $qemuString"
  }
  // Build the compile command using the cross compiler
  val compileCommand =
    s"aarch64-linux-gnu-gcc -o $quotedExecutable -z noexecstack -march=armv8-a $quotedAsm"
  val execCommand = s"$compileCommand && $runCommand"

  // Execute the command using the shell, setting cwd to the asm directory
  val resultTry = Try(os.proc("sh", "-c", execCommand)
    .call(cwd = asmDir, stdout = os.Pipe, stderr = os.Pipe, check = false))
  resultTry match {
    case Failure(e) => (false, s"Exception while running command: ${e.getMessage}")
    case Success(result) =>
      val rawOutput = result.out.text().trim.linesIterator.toList
      val actualOutput = normaliseOutput(rawOutput)
      val actualExit = result.exitCode
      // Mismatches in reporting order: exit code first, then output
      val mismatches = List(
        Option.when(actualExit != expectedExit)(
          s"Expected exit code $expectedExit but got $actualExit."),
        Option.when(actualOutput != expectedOutput)(
          s"Expected output:\n${expectedOutput.mkString("\n")}\nBut got:\n${actualOutput.mkString("\n")}")
      ).flatten
      if (mismatches.isEmpty) {
        (true, "")
      } else {
        // Surface stderr (e.g. assembler diagnostics) only when the check already failed
        val stderrText = result.err.text().trim
        val diagnostics = if (stderrText.nonEmpty) List(s"Stderr:\n$stderrText") else Nil
        (false, (mismatches ++ diagnostics).mkString("\n"))
      }
  }
}
/** Normalises program output so that it can be compared with the expected output.
  *
  * For every line, independently:
  *  - each hexadecimal address (0x followed by hex digits) becomes #addrs#
  *  - a line starting with "fatal error:" becomes #runtime_error#
  *
  * @param rawOutput the output lines produced by the executed program
  * @return the normalised lines, in the same order
  */
def normaliseOutput(rawOutput: List[String]): List[String] = {
  val hexAddress = "0x[0-9A-Fa-f]+".r
  val fatalErrorLine = "^fatal error:.*$".r
  rawOutput.map { line =>
    // Addresses are masked first, then the runtime-error rewrite is applied to the result
    val withoutAddresses = hexAddress.replaceAllIn(line, "#addrs#")
    fatalErrorLine.replaceAllIn(withoutAddresses, "#runtime_error#")
  }
}
}

121
src/test/wacc/imports.scala Normal file
View File

@ -0,0 +1,121 @@
import org.scalatest.flatspec.AnyFlatSpec
import wacc.frontend.syntax.ImportHandler
class importsTest extends CodeGeneratorTestSuite {
  /* Wrapper for uniform testing: runs the shared pipeline with a fresh import
   * handler and no peephole optimiser. */
  def test(input: os.Path): (Boolean, String) =
    super.test(input, None, Some(ImportHandler()))

  // IGNORE SET: Contains files/directories to ignore in testing
  val ignoreSet: Set[os.Path] = Set()

  // ----------------------------------------------------
  // Programs with syntax errors: must fail
  // ----------------------------------------------------
  behavior of "TESTING IMPORT HANDLER IN syntaxErr"
  Seq(
    invalidSyntaxArray,
    invalidSyntaxBasic,
    invalidSyntaxExpressions,
    invalidSyntaxFunction,
    invalidSyntaxIf,
    invalidSyntaxLiteral,
    invalidSyntaxPairs,
    invalidSyntaxPrint,
    invalidSyntaxSequence,
    invalidSyntaxVariables,
    invalidSyntaxWhile
  ).foreach(dir => genTest("Import Handler", ignoreSet, test, "fail", os.list(dir)))

  // ----------------------------------------------------
  // Programs with semantic errors: must fail
  // ----------------------------------------------------
  behavior of "TESTING IMPORT HANDLER IN semanticErr"
  Seq(
    invalidSemanticArray,
    invalidSemanticExit,
    invalidSemanticExpressions,
    invalidSemanticFunction,
    invalidSemanticIf,
    invalidSemanticIO,
    invalidSemanticMultiple,
    invalidSemanticPairs,
    invalidSemanticPrint,
    invalidSemanticRead,
    invalidSemanticScope,
    invalidSemanticVariables,
    invalidSemanticWhile
  ).foreach(dir => genTest("Import Handler", ignoreSet, test, "fail", os.list(dir)))

  // ----------------------------------------------------
  // Valid programs: one behaviour group per directory, must succeed
  // ----------------------------------------------------
  Seq(
    "valid/basic/exit" -> validBasicExit,
    "valid/basic/skip" -> validBasicSkip,
    "valid/expressions" -> validExpression,
    "valid/function/nested_functions" -> validFunctionNestFuns,
    "valid/function/simple_functions" -> validFunctionSimpFuns,
    "valid/function/overload_functions" -> validFunctionOverFuns,
    "valid/function/imported_functions" -> validFunctionImports,
    "valid/if" -> validIf,
    "valid/IO/print" -> validIOPrint,
    "valid/IO/read" -> validIORead,
    "valid/IO/special" -> validIOSpecial,
    "valid/pairs" -> validPair,
    "valid/runtimeErr/arrayOutOfBounds" -> validRTEArrOOB,
    "valid/runtimeErr/badChar" -> validRTEBadChar,
    "valid/runtimeErr/divideByZero" -> validRTEDivByZero,
    "valid/runtimeErr/integerOverflow" -> validRTEIntOverflow,
    "valid/runtimeErr/nullDereference" -> validRTENullDereference,
    "valid/scope" -> validScope,
    "valid/sequence" -> validSequence,
    "valid/variables" -> validVariables,
    "valid/while" -> validWhile
  ).foreach { case (label, dir) =>
    behavior of s"TESTING IMPORT HANDLER IN $label"
    genTest("Import Handler", ignoreSet, test, "succeed", os.list(dir))
  }
}

View File

@ -0,0 +1,23 @@
import org.scalatest.flatspec.AnyFlatSpec
import wacc.frontend.syntax.ImportHandler
class importsParallelTest extends CodeGeneratorTestSuite {
  /* Wrapper for uniform testing: runs the shared pipeline with a fresh import
   * handler and no peephole optimiser.
   *
   * NOTE(review): printParallelTests = false disables the timing printout, but
   * CodeGeneratorTestSuite.test also forwards this flag as `optimiseFlag` to
   * addImports, so this suite appears to exercise the non-parallel import path.
   * Confirm whether that is intended. */
  def test(input: os.Path): (Boolean, String) =
    super.test(input, None, Some(ImportHandler()), printParallelTests = false)

  // IGNORE SET: Contains files/directories to ignore in testing
  val ignoreSet: Set[os.Path] = Set()

  // ----------------------------------------------------
  // Valid programs that use imported functions: must succeed
  // ----------------------------------------------------
  behavior of "TESTING PARALLEL IMPORT HANDLER IN valid/function/imported_functions"
  genTest("Import Handler", ignoreSet, test, "succeed", os.list(validFunctionImports))
}

171
src/test/wacc/parser.scala Normal file
View File

@ -0,0 +1,171 @@
import org.scalatest.flatspec.AnyFlatSpec
import parsley.{Success, Failure}
class ParserTests extends ParserTestSuite {
  /**
    * Reports whether a wacc file parses.
    *
    * @param input Path to wacc file
    * @return (true, "") when parsing succeeds, (false, msg) with the parser's message otherwise
    */
  def test(input: os.Path): (Boolean, String) =
    parseFile(input) match {
      case Failure(msg) => (false, msg)
      case Success(_) => (true, "")
    }

  // IGNORE SET: Contains files/directories to ignore in testing
  val ignoreSet: Set[os.Path] = Set()

  // ----------------------------------------------------
  // Programs with syntax errors: parsing must fail
  // ----------------------------------------------------
  Seq(
    "array" -> invalidSyntaxArray,
    "basic" -> invalidSyntaxBasic,
    "expressions" -> invalidSyntaxExpressions,
    "function" -> invalidSyntaxFunction,
    "if" -> invalidSyntaxIf,
    "literals" -> invalidSyntaxLiteral,
    "pairs" -> invalidSyntaxPairs,
    "print" -> invalidSyntaxPrint,
    "sequence" -> invalidSyntaxSequence,
    "variables" -> invalidSyntaxVariables,
    "while" -> invalidSyntaxWhile
  ).foreach { case (label, dir) =>
    behavior of s"TESTING PARSER IN syntaxErr/$label"
    genTest("Parser", ignoreSet, test, "fail", os.list(dir))
  }

  // ----------------------------------------------------
  // Syntactically valid programs: parsing must succeed
  // ----------------------------------------------------
  Seq(
    "advanced" -> validAdvanced,
    "array" -> validArray,
    "basic/exit" -> validBasicExit,
    "basic/skip" -> validBasicSkip,
    "expressions" -> validExpression,
    "function/nested_functions" -> validFunctionNestFuns,
    "function/simple_functions" -> validFunctionSimpFuns,
    "function/overload_functions" -> validFunctionOverFuns,
    "if" -> validIf,
    "IO/print" -> validIOPrint,
    "IO/read" -> validIORead,
    "IO/special" -> validIOSpecial,
    "pairs" -> validPair,
    "runtimeErr/arrayOutOfBounds" -> validRTEArrOOB,
    "runtimeErr/badChar" -> validRTEBadChar,
    "runtimeErr/divideByZero" -> validRTEDivByZero,
    "runtimeErr/integerOverflow" -> validRTEIntOverflow,
    "runtimeErr/nullDereference" -> validRTENullDereference,
    "scope" -> validScope,
    "sequence" -> validSequence,
    "variables" -> validVariables,
    "while" -> validWhile
  ).foreach { case (label, dir) =>
    behavior of s"TESTING PARSER IN valid/$label"
    genTest("Parser", ignoreSet, test, "succeed", os.list(dir))
  }

  // ----------------------------------------------------
  // Semantically invalid programs: still syntactically valid, so parsing must succeed
  // ----------------------------------------------------
  Seq(
    "array" -> invalidSemanticArray,
    "exit" -> invalidSemanticExit,
    "expressions" -> invalidSemanticExpressions,
    "function" -> invalidSemanticFunction,
    "if" -> invalidSemanticIf,
    "IO" -> invalidSemanticIO,
    "multiple" -> invalidSemanticMultiple,
    "pairs" -> invalidSemanticPairs,
    "print" -> invalidSemanticPrint,
    "read" -> invalidSemanticRead,
    "scope" -> invalidSemanticScope,
    "variables" -> invalidSemanticVariables,
    "while" -> invalidSemanticWhile
  ).foreach { case (label, dir) =>
    behavior of s"TESTING PARSER IN semanticErr/$label"
    genTest("Parser", ignoreSet, test, "succeed", os.list(dir))
  }
}

View File

@ -0,0 +1,119 @@
import org.scalatest.flatspec.AnyFlatSpec
import wacc.extension.Peephole
// These tests are here to ensure that all previous tests still pass after the peephole optimization is applied
class peepholeTests extends CodeGeneratorTestSuite {
/* Wrapper for uniform testing */
def test(input: os.Path): (Boolean, String) = {
// Create the peephole instance
val peephole = Peephole()
super.test(input, Some(peephole), None)
}
// IGNORE SET: Contains files/directories to ignore in testing
val ignoreSet: Set[os.Path] = Set()
// ----------------------------------------------------
// Testing programs with syntax errors
// ----------------------------------------------------
behavior of "TESTING PEEPHOLE IN syntaxErr"
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxArray))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxBasic))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxExpressions))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxFunction))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxIf))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxLiteral))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxPairs))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxPrint))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxSequence))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxVariables))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxWhile))
// ----------------------------------------------------
// Testing programs with semantic errors
// ----------------------------------------------------
behavior of "TESTING PEEPHOLE IN semanticErr"
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticArray))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticExit))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticExpressions))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticFunction))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticIf))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticIO))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticMultiple))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticPairs))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticPrint))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticRead))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticScope))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticVariables))
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticWhile))
// ----------------------------------------------------
// Testing valid programs
// ----------------------------------------------------
behavior of "TESTING PEEPHOLE IN valid/basic/exit"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validBasicExit))
behavior of "TESTING PEEPHOLE IN valid/basic/skip"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validBasicSkip))
behavior of "TESTING PEEPHOLE IN valid/expressions"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validExpression))
behavior of "TESTING PEEPHOLE IN valid/function/nested_functions"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validFunctionNestFuns))
behavior of "TESTING PEEPHOLE IN valid/function/simple_functions"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validFunctionSimpFuns))
behavior of "TESTING PEEPHOLE IN valid/function/overload_functions"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validFunctionOverFuns))
behavior of "TESTING PEEPHOLE IN valid/if"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validIf))
behavior of "TESTING PEEPHOLE IN valid/IO/print"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validIOPrint))
behavior of "TESTING PEEPHOLE IN valid/IO/read"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validIORead))
behavior of "TESTING PEEPHOLE IN valid/IO/special"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validIOSpecial))
behavior of "TESTING PEEPHOLE IN valid/pairs"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validPair))
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/arrayOutOfBounds"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTEArrOOB))
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/badChar"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTEBadChar))
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/divideByZero"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTEDivByZero))
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/integerOverflow"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTEIntOverflow))
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/nullDereference"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTENullDereference))
behavior of "TESTING PEEPHOLE IN valid/scope"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validScope))
behavior of "TESTING PEEPHOLE IN valid/sequence"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validSequence))
behavior of "TESTING PEEPHOLE IN valid/variables"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validVariables))
behavior of "TESTING PEEPHOLE IN valid/while"
genTest("Peephole", ignoreSet, test, "succeed", os.list(validWhile))
}

View File

@ -0,0 +1,143 @@
import org.scalatest.flatspec.AnyFlatSpec
import parsley.{Success, Failure}
import wacc.frontend.syntax.ast.WProgram
import wacc.frontend.semantic.analyse
class semanticAnalyserTests extends ParserTestSuite {

  /**
    * Parses a wacc file and, when parsing succeeds, runs semantic analysis on the result.
    *
    * @param input Path to wacc file
    * @return (true, "") on a successful analysis, or (false, msg) otherwise, where msg holds
    *         either the semantic error messages or the parser's own error output
    */
  def test(input: os.Path): (Boolean, String) =
    parseFile(input) match {
      case Success(program: WProgram) =>
        analyse(program) match {
          // NOTE: `analyse` signals success on the Left and carries the error list on the Right
          case Left(_) => (true, "")
          case Right(errs) =>
            val fileName = input.last
            (false, errs.map(err => err.getMessage(fileName, input.toString)).mkString("\n"))
        }
      // Keep the parser's diagnostic so a failing assertion shows *why* parsing failed
      case Failure(msg) => (false, s"OOPS. Looks like the parser failed :/\n$msg")
    }

  // IGNORE SET: Contains files/directories to ignore in testing
  val ignoreSet: Set[os.Path] = Set()

  // Directories of programs that must be rejected by the semantic analyser,
  // keyed by the label shown in the test report (order = registration order)
  private val semanticErrorSuites: Seq[(String, os.Path)] = Seq(
    "semanticErr/array"       -> invalidSemanticArray,
    "semanticErr/exit"        -> invalidSemanticExit,
    "semanticErr/expressions" -> invalidSemanticExpressions,
    "semanticErr/function"    -> invalidSemanticFunction,
    "semanticErr/if"          -> invalidSemanticIf,
    "semanticErr/IO"          -> invalidSemanticIO,
    "semanticErr/multiple"    -> invalidSemanticMultiple,
    "semanticErr/pairs"       -> invalidSemanticPairs,
    "semanticErr/print"       -> invalidSemanticPrint,
    "semanticErr/read"        -> invalidSemanticRead,
    "semanticErr/scope"       -> invalidSemanticScope,
    "semanticErr/variables"   -> invalidSemanticVariables,
    "semanticErr/while"       -> invalidSemanticWhile
  )

  // Directories of programs that must be accepted by the semantic analyser,
  // keyed by the label shown in the test report (order = registration order)
  private val validSuites: Seq[(String, os.Path)] = Seq(
    "valid/advanced"                    -> validAdvanced,
    "valid/array"                       -> validArray,
    "valid/basic/exit"                  -> validBasicExit,
    "valid/basic/skip"                  -> validBasicSkip,
    "valid/expressions"                 -> validExpression,
    "valid/function/nested_functions"   -> validFunctionNestFuns,
    "valid/function/simple_functions"   -> validFunctionSimpFuns,
    "valid/function/overload_functions" -> validFunctionOverFuns,
    "valid/if"                          -> validIf,
    "valid/IO/print"                    -> validIOPrint,
    "valid/IO/read"                     -> validIORead,
    "valid/IO/special"                  -> validIOSpecial,
    "valid/pairs"                       -> validPair,
    "valid/runtimeErr/arrayOutOfBounds" -> validRTEArrOOB,
    "valid/runtimeErr/badChar"          -> validRTEBadChar,
    "valid/runtimeErr/divideByZero"     -> validRTEDivByZero,
    "valid/runtimeErr/integerOverflow"  -> validRTEIntOverflow,
    "valid/runtimeErr/nullDereference"  -> validRTENullDereference,
    "valid/scope"                       -> validScope,
    "valid/sequence"                    -> validSequence,
    "valid/variables"                   -> validVariables,
    "valid/while"                       -> validWhile
  )

  // ----------------------------------------------------
  // Testing programs with semantic errors
  // ----------------------------------------------------
  for ((label, dir) <- semanticErrorSuites) {
    behavior of s"TESTING SEMANTIC ANALYSER IN $label"
    genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(dir))
  }

  // ----------------------------------------------------
  // Testing semantically valid programs
  // ----------------------------------------------------
  for ((label, dir) <- validSuites) {
    behavior of s"TESTING SEMANTIC ANALYSER IN $label"
    genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(dir))
  }
}

View File

@ -0,0 +1,131 @@
import org.scalatest.flatspec.AnyFlatSpec
import parsley.Result
import wacc.frontend.syntax.{parser, ast}
import wacc.frontend.syntax.ast._
// ----------------------------------------------------
// LOCATIONS OF THE WACC EXAMPLE PROGRAMS USED BY THE TEST SUITES
// (all resolved relative to the current working directory)
// ----------------------------------------------------
val basePath: os.Path            = os.pwd / "src" / "test" / "wacc" / "waccPrograms"
val invalidSyntaxPath: os.Path   = basePath / "syntaxErr"
val invalidSemanticPath: os.Path = basePath / "semanticErr"
val validPath: os.Path           = basePath / "valid"

// Programs expected to be rejected with a syntax error
val invalidSyntaxArray: os.Path       = invalidSyntaxPath / "array"
val invalidSyntaxBasic: os.Path       = invalidSyntaxPath / "basic"
val invalidSyntaxExpressions: os.Path = invalidSyntaxPath / "expressions"
val invalidSyntaxFunction: os.Path    = invalidSyntaxPath / "function"
val invalidSyntaxIf: os.Path          = invalidSyntaxPath / "if"
val invalidSyntaxLiteral: os.Path     = invalidSyntaxPath / "literals"
val invalidSyntaxPairs: os.Path       = invalidSyntaxPath / "pairs"
val invalidSyntaxPrint: os.Path       = invalidSyntaxPath / "print"
val invalidSyntaxSequence: os.Path    = invalidSyntaxPath / "sequence"
val invalidSyntaxVariables: os.Path   = invalidSyntaxPath / "variables"
val invalidSyntaxWhile: os.Path       = invalidSyntaxPath / "while"

// Programs expected to be rejected with a semantic error
val invalidSemanticArray: os.Path       = invalidSemanticPath / "array"
val invalidSemanticExit: os.Path        = invalidSemanticPath / "exit"
val invalidSemanticExpressions: os.Path = invalidSemanticPath / "expressions"
val invalidSemanticFunction: os.Path    = invalidSemanticPath / "function"
val invalidSemanticIf: os.Path          = invalidSemanticPath / "if"
val invalidSemanticIO: os.Path          = invalidSemanticPath / "IO"
val invalidSemanticMultiple: os.Path    = invalidSemanticPath / "multiple"
val invalidSemanticPairs: os.Path       = invalidSemanticPath / "pairs"
val invalidSemanticPrint: os.Path       = invalidSemanticPath / "print"
val invalidSemanticRead: os.Path        = invalidSemanticPath / "read"
val invalidSemanticScope: os.Path       = invalidSemanticPath / "scope"
val invalidSemanticVariables: os.Path   = invalidSemanticPath / "variables"
val invalidSemanticWhile: os.Path       = invalidSemanticPath / "while"

// Programs expected to pass both syntax and semantic checks
val validAdvanced: os.Path           = validPath / "advanced"
val validArray: os.Path              = validPath / "array"
val validBasicExit: os.Path          = validPath / "basic" / "exit"
val validBasicSkip: os.Path          = validPath / "basic" / "skip"
val validExpression: os.Path         = validPath / "expressions"
val validFunctionNestFuns: os.Path   = validPath / "function" / "nested_functions"
val validFunctionSimpFuns: os.Path   = validPath / "function" / "simple_functions"
val validFunctionOverFuns: os.Path   = validPath / "function" / "overload_functions"
val validFunctionImports: os.Path    = validPath / "function" / "imported_functions"
val validIf: os.Path                 = validPath / "if"
val validIOSpecial: os.Path          = validPath / "IO" / "special"
val validIOPrint: os.Path            = validPath / "IO" / "print"
val validIORead: os.Path             = validPath / "IO" / "read"
val validPair: os.Path               = validPath / "pairs"
val validRTEArrOOB: os.Path          = validPath / "runtimeErr" / "arrayOutOfBounds"
val validRTEBadChar: os.Path         = validPath / "runtimeErr" / "badChar"
val validRTEDivByZero: os.Path       = validPath / "runtimeErr" / "divideByZero"
val validRTEIntOverflow: os.Path     = validPath / "runtimeErr" / "integerOverflow"
val validRTENullDereference: os.Path = validPath / "runtimeErr" / "nullDereference"
val validScope: os.Path              = validPath / "scope"
val validSequence: os.Path           = validPath / "sequence"
val validVariables: os.Path          = validPath / "variables"
val validWhile: os.Path              = validPath / "while"
// ----------------------------------------------------
// HELPER FUNCTIONS
// ----------------------------------------------------
trait ParserTestSuite extends AnyFlatSpec {

  /**
    * Reads a wacc source file and runs the parser over its contents.
    *
    * @param input Path to wacc file
    * @return the parser's result: the parsed program on success, or an error message on failure
    */
  def parseFile(input: os.Path): Result[String, WProgram] =
    parser.parse(os.read(input))

  /**
    * Registers one test case for each path in `files`.
    *
    * A path that lies under any entry of `ignoreSet` is registered as an ignored test;
    * every other path is registered as a normal test that runs `testFunc` on it.
    *
    * `testFunc` is invoked inside the test body, so it runs when the test executes rather
    * than while the suite is being constructed. An exception thrown by `testFunc` therefore
    * fails only that one test instead of aborting the whole suite.
    *
    * @param analyser  Name of the component under test; only used in the failure message
    * @param ignoreSet Files/directories whose tests should be registered as ignored
    * @param testFunc  Runs the component on a file, returning (verdict, message)
    * @param res       "fail" if `testFunc` is expected to return a false verdict;
    *                  any other value (conventionally "succeed") expects a true verdict
    * @param files     Sequence of files within the directory of interest
    */
  def genTest(analyser: String,
              ignoreSet: Set[os.Path],
              testFunc: (os.Path) => (Boolean, String),
              res: String,
              files: IndexedSeq[os.Path]): Unit = {
    // Anything other than "fail" is treated as an expectation of success
    val expectedVerdict = res != "fail"

    for (file <- files) {
      val fileName = file.last

      if (ignoreSet.exists(file.startsWith)) {
        // --------------------------
        // IGNORING TESTS
        // --------------------------
        ignore should s"$res with $fileName" in {
          // *** test is ignored ***
        }
      } else {
        // --------------------------
        // TESTING REMAINING TESTS
        // --------------------------
        it should s"$res with $fileName" in {
          val (verdict, msg) = testFunc(file)
          assert(verdict == expectedVerdict, s"Expected $analyser to $res. $analyser returns:\n$msg")
        }
      }
    }
  }
}

View File

@ -0,0 +1,15 @@
# attempt to read into a boolean variable
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
bool b = true ;
read b ;
println b
end

View File

@ -0,0 +1,16 @@
# Attempting to access an array with an invalid complex expression
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int[] a = [1, 2];
int b = a[1 - "horse"];
int c = a[!false]
end

View File

@ -0,0 +1,15 @@
# Attempting to access an array with a non-integer index.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int[] a = [1, 2];
int b = a["horse"]
end

View File

@ -0,0 +1,20 @@
# Indexing an array to get sub-arrays, but going too far.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int[] a = [1, 2];
int[] b = [3, 4];
int[][] ab = [a, b];
int[] sameAsA = ab[0];
int oops = sameAsA[0][1]
end

View File

@ -0,0 +1,15 @@
# Too much indexing!
# Thanks to David Pan
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int[] a = [1];
println a[1][2]
end

View File

@ -0,0 +1,14 @@
# Attempting to array-index an undefined identifier.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
char c = horse[2]
end

View File

@ -0,0 +1,19 @@
# Attempting to mix types in an array literal.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
char five() is
begin return '5' end
end
char f = call five();
int[] a = [1, f]
end

View File

@ -0,0 +1,15 @@
# Arrays should not be covariant
# Thanks to Nathaniel Burke
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
char[][] acs = [] ;
string[] bad = acs
end

View File

@ -0,0 +1,14 @@
# It shouldn't be possible to index strings (wrong side of the type relaxation)
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
string str = "hello world";
char x = str[0]
end

View File

@ -0,0 +1,16 @@
# Trying to assign the wrong array dimension to a variable, or non-matching pairs.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int[] a = [1, 2];
int[][] aa = [a, a];
int[] b = aa
end

View File

@ -0,0 +1,15 @@
# Accessing too many array indices.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int[] a = [1, 2];
int oops = a[1][2][3]
end

View File

@ -0,0 +1,14 @@
# Array was given wrong type.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int[] should_be_int_array = ['a', 'b', 'c']
end

View File

@ -0,0 +1,13 @@
# tries to exit using a character
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
exit 'a'
end

View File

@ -0,0 +1,14 @@
# exit with non-int - this should be an invalid program!
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
char x = 'f' ;
exit x
end

View File

@ -0,0 +1,14 @@
# trying to return from the main program
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
return 42 ;
println "should not get here"
end

View File

@ -0,0 +1,26 @@
# Returning from the main body is forbidden.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
while true do
return 3
done;
if true then
return 4
else
return 5
fi;
begin
return 6
end
end

View File

@ -0,0 +1,13 @@
# expression type mismatch int->bool
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
bool b = 1 || 1
end

View File

@ -0,0 +1,13 @@
# expression type mismatch bool->int
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int b = 15 + 6 || 19
end

View File

@ -0,0 +1,13 @@
# expression type mismatch bool->int
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int x = true + false
end

View File

@ -0,0 +1,17 @@
# evaluating less-than on references
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
pair(int,int) x = newpair(1,2) ;
println x;
pair(int,int) y = newpair(2,3) ;
println y;
println x < y
end

View File

@ -0,0 +1,13 @@
# expression type mismatch bool->int
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int b = 1 + 2 + true + 4 + 5
end

View File

@ -0,0 +1,15 @@
# program performs boolean operations on arrays
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int[] x = [1,2];
int[] y = [3,4];
println x > y
end

View File

@ -0,0 +1,16 @@
# element access is not permitted for strings
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
string str = "hello world!" ;
println str ;
str[0] = 'H' ;
println str
end

View File

@ -0,0 +1,23 @@
# Two overloads with different pair parameters when called with a
# `null` directly leads to ambiguity.
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(pair(int, char) x) is
return 1
end
int f(pair(int, bool) x) is
return 2
end
int result = call f(null) ;
println result
end

View File

@ -0,0 +1,25 @@
# Overloads "fun" in a way that calling with a char array
# Could match multiple signatures (char[] or string)
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int fun(string x) is
int y = 1;
return y
end
int fun(char[] x) is
return len x
end
char[] text = ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'] ;
int result = call fun(text) ;
println result
end

View File

@ -0,0 +1,27 @@
# Two overloads with identical parameter lists but different return types
# leads to ambiguity.
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(int x) is
return x
end
bool f(int x) is
return x > 0
end
bool result = call f(10) ;
if result then
println "Positive"
else
println "Non-positive"
fi
end

View File

@ -0,0 +1,24 @@
# Defines overloaded "f" for int and bool, but calls with a char.
# There's no matching overload that takes a char.
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(int x) is
return x + 1
end
bool f(bool x) is
return !x
end
# Attempt to call f with a char argument -> no matching overload
int test = call f('z') ;
println test
end

View File

@ -0,0 +1,14 @@
# Calling an undefined identifier.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int result = call fib(5)
end

View File

@ -0,0 +1,17 @@
# functions cannot define the same argument twice
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int foo(int x, int x) is
return x
end
skip
end

View File

@ -0,0 +1,20 @@
# functions cannot access global variables
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f() is
x = -1 ;
return 0
end
int x = 5 ;
int y = call f() ;
println x
end

View File

@ -0,0 +1,17 @@
# tries to assign to a function
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f() is
return 3
end
f = 2
end

View File

@ -0,0 +1,18 @@
# function parameter misuse
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(int x) is
bool b = x && true ;
return 0
end
int x = call f(0)
end

View File

@ -0,0 +1,18 @@
# function call type mismatch
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f() is
return 0
end
bool b = call f() ;
println b
end

View File

@ -0,0 +1,23 @@
# function call type mismatch for both functions
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f() is
return 0
end
bool f() is
return false
end
string s = call f() ;
println s
end

View File

@ -0,0 +1,18 @@
# function parameter type mismatch
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(int x) is
return x
end
bool b = true ;
int x = call f(b)
end

View File

@ -0,0 +1,24 @@
# function parameter type mismatch for either function
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(pair(int, int) x) is
int y = fst x ;
return y
end
char f(int x) is
return chr x
end
bool b = true ;
int x = call f(b)
end

View File

@ -0,0 +1,18 @@
# function return type mismatch: int <- char
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f() is
return 'c'
end
int x = call f() ;
println x
end

View File

@ -0,0 +1,23 @@
# function return type mismatch: int <- char even though there's an overloaded function with the correct return
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
char f() is
return 'c'
end
int f() is
return 'c'
end
int x = call f() ;
println x
end

View File

@ -0,0 +1,23 @@
# function return type mismatch: char <- int even though there's an overloaded function with the correct return
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
char f() is
return 1
end
int f() is
return 'a'
end
char x = call f() ;
println x
end

View File

@ -0,0 +1,18 @@
# function call with too many arguments
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(int a, int b) is
return 0
end
int x = call f(1,2,3);
println x
end

View File

@ -0,0 +1,23 @@
# function call with more arguments than either function
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(int a) is
return 0
end
int f(int a, int b) is
return 0
end
int x = call f(1,2,3);
println x
end

View File

@ -0,0 +1,20 @@
# attempted redefinition of function of the same signature
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f() is
return 0
end
int f() is
return 1
end
int x = call f();
println x
end

View File

@ -0,0 +1,20 @@
# attempted redefinition of function of the same signature
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(int a, string b, char c, bool d, pair(int, int) e, int[] f) is
return 0
end
int f(int a, string b, char c, bool d, pair(int, int) e, int[] f) is
return 1
end
skip
end

View File

@ -0,0 +1,19 @@
# function call with arguments swapped
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(int a, bool b) is
return 0
end
int x = call f(true,1);
println x
end

View File

@ -0,0 +1,19 @@
# function call with too few arguments
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(int a, int b) is
return 0
end
int x = call f(1);
println x
end

View File

@ -0,0 +1,22 @@
# Trying to return something invalid from a branch.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f() is
if true then
return horse
else
return 1
fi
end
int result = call f()
end

View File

@ -0,0 +1,27 @@
# Trying to return two different types from a function.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(bool b) is
if b then
return 'c';
while true do
return 'd'
done
else
skip
fi;
return 10
end
int a = call f(false)
end

View File

@ -0,0 +1,17 @@
# if condition type mismatch
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
if 15 + 6 then
skip
else
skip
fi
end

View File

@ -0,0 +1,25 @@
# a complete mess of function definition and use
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f(bool b) is
x = -1 ;
int x = 1;
bool x = 3;
return "done"
end
int x = 5 ;
int y = call f(x) ;
println y ;
println x
end

View File

@ -0,0 +1,19 @@
# boolean typo and if condition type mismatch
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
while tru do
if 15 + 6 then
skip
else
skip
fi
done
end

View File

@ -0,0 +1,16 @@
# long expression with multiple type mismatches
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int x = ((1 + true) * (2 - false)) || 17 ;
exit x
end

View File

@ -0,0 +1,20 @@
# variable names are case sensitive in all possible ways
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int number = 42 ;
println NUMBER ;
int INDEX = 0 ;
println index ;
int miXed = 3 ;
println MIxED
end

View File

@ -0,0 +1,15 @@
# multiple type mismatches: int <- bool, bool <- char, char <- int
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int i = true ;
bool b = 'a' ;
char c = 10
end

View File

@ -0,0 +1,44 @@
# Trying to obfuscate invalid returns with whiles.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int f() is
if true then
while true do
return 'a'
done
else
return 5
fi;
if false then
return 2
else
while false do
return 'b'
done
fi;
if true then
while true do
return -2
done
else
while false do
return -4
done
fi;
while false do
return !"horse"
done;
exit -1
end
int i = call f()
end

View File

@ -0,0 +1,13 @@
# newpair cannot be assigned to non-pair type
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int x = newpair(10, 20)
end

View File

@ -0,0 +1,15 @@
# Assignment is not legal when the types of both sides are not known
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
pair(int, int) p = newpair(4, 5);
pair(pair, int) q = newpair(p, 6);
fst fst q = snd fst q
end

View File

@ -0,0 +1,16 @@
# call free on a non-pair type
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int x = 5;
println x;
free 9 ;
println x
end

View File

@ -0,0 +1,13 @@
# newpair must match the underlying type of the pair
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
pair(char, bool) x = newpair(10, 20)
end

View File

@ -0,0 +1,15 @@
# Pairs should not be covariant
# Thanks to Nathaniel Burke
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
pair(char[], char[]) pcs = null ;
pair(string, string) bad = pcs
end

View File

@ -0,0 +1,16 @@
# Trying to assign non-matching pairs to each other.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
pair(int, int) p1 = newpair(0, 0);
pair(char, char) p2 = newpair('a', 'a');
p2 = p1
end

View File

@ -0,0 +1,14 @@
# Reading is not legal when the type is not known
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
pair(pair, int) p = null;
read fst fst p
end

View File

@ -0,0 +1,18 @@
# Trying to place an incorrect type into a parameterless pair.
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int c = 5;
pair(int, int) p = newpair(c, c);
pair(pair, int) oops = newpair(p, 0);
fst oops = c
end

View File

@ -0,0 +1,15 @@
# type mismatch: int <- char
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
int x = 4 ;
char y = 'a' ;
println x + y
end

View File

@ -0,0 +1,14 @@
# If we are reading into a pair projection, the element must be readable
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
pair(bool, bool) p = null;
read fst p
end

View File

@ -0,0 +1,14 @@
# If we are reading into a pair projection, the element must be readable
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
pair(bool, bool) p = null;
read snd p
end

View File

@ -0,0 +1,15 @@
# read into pair type
# Output:
# #semantic_error#
# Exit:
# 200
# Program:
begin
pair(int, int) p = newpair(1,2);
read p
end

Some files were not shown because too many files have changed in this diff Show More