Initial commit - copied from GitLab
This commit is contained in:
commit
4d6426c899
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
.metals/
|
||||
.bsp/
|
||||
.scala-build/
|
||||
.vscode/
|
||||
16
.gitlab-ci.yml
Normal file
16
.gitlab-ci.yml
Normal file
@ -0,0 +1,16 @@
|
||||
stages:
|
||||
- compile
|
||||
- test
|
||||
|
||||
compile:
|
||||
stage: compile
|
||||
script:
|
||||
- scala .
|
||||
|
||||
test:
|
||||
stage: test
|
||||
script:
|
||||
- scala test . # This will test everything. Will need to be modified as we add tests
|
||||
|
||||
# More stages will be added later to independently test different phases
|
||||
# To allow chosen tests to run as we add functionality
|
||||
21
Makefile
Normal file
21
Makefile
Normal file
@ -0,0 +1,21 @@
|
||||
# NOTE: PLEASE DON'T USE THIS MAKEFILE, IT IS FOR LABTS
|
||||
# it is *much* more efficient to use `scala compile .` trust me, I'm watching you.
|
||||
all:
|
||||
# the --server=false flag helps improve performance on LabTS by avoiding
|
||||
# downloading the build-server "bloop".
|
||||
# the --jvm system flag helps improve performance on LabTS by preventing
|
||||
# scala-cli from downloading a whole jdk distribution on the lab machine
|
||||
# the --force flag ensures that any existing built compiler is overwritten
|
||||
# the --power flag is needed as `package` is an experimental "power user" feature (NOTE: use this or --assembly if anything goes wrong)
|
||||
# scala --power package . --server=false --jvm system --force -o wacc-compiler
|
||||
# you can use --assembly to make it built a self-contained jar,
|
||||
# scala --power package . --server=false --jvm system --assembly --force -o wacc-compiler
|
||||
# you can use --native to make it build a native application (requiring Scala Native),
|
||||
# scala --power package . --server=false --jvm system --native --force -o wacc-compiler
|
||||
# or you can use --graalvm-jvm-id graalvm-java21 --native-image to build it using graalvm
|
||||
scala --power package . --server=false --jvm system --graalvm-jvm-id graalvm-java21 --native-image --force -o wacc-compiler
|
||||
|
||||
clean:
|
||||
scala clean . && rm -f wacc-compiler
|
||||
|
||||
.PHONY: all clean
|
||||
35
README.md
Normal file
35
README.md
Normal file
@ -0,0 +1,35 @@
|
||||
This is the provided git repository for the WACC compilers lab. You should work
|
||||
in this repository regularly committing and pushing your work back to GitLab.
|
||||
|
||||
# Provided files/directories
|
||||
|
||||
## src/main
|
||||
|
||||
The src/main directory is where you code for your compiler should go, and just
|
||||
contains a stub hello world file with a simple calculator inside.
|
||||
|
||||
## src/test
|
||||
The src/test directory is where you should put the code for your tests, which
|
||||
can be ran via `scala-cli test .`. The suggested framework is `scalatest`, the dependency
|
||||
for which has already been included.
|
||||
|
||||
## project.scala
|
||||
The `project.scala` is the definition of your project's build requirements. By default,
|
||||
this skeleton has added the latest stable versions of both `scalatest` and `parsley`
|
||||
to the build: you should check **regularly** to see if your `parsley` needs updating
|
||||
during the course of WACC!
|
||||
|
||||
## compile
|
||||
|
||||
The compile script can be edited to change the frontend interface to your WACC
|
||||
compiler. You are free to change the language used in this script, but do not
|
||||
change its name.
|
||||
|
||||
## Makefile
|
||||
|
||||
Your Makefile should be edited so that running 'make' in the root directory
|
||||
builds your WACC compiler. Currently running 'make' will call
|
||||
`scala --power package . --server=false --jvm system --graalvm-jvm-id graalvm-java21 --native-image --force -o wacc-compiler`, producing a file called
|
||||
`wacc-compiler`
|
||||
in the root directory of the project. If this doesn't work for whatever reason, there are a few
|
||||
different alternatives you can try in the makefile. **Do not use the makefile as you're working, it's for labts/CI!**
|
||||
9
compile
Executable file
9
compile
Executable file
@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
# Bash front-end for your compiler.
|
||||
# You are free to change the language used for this script,
|
||||
# but do *not* change its name.
|
||||
|
||||
# feel free to adjust to suit the specific internal flags of your compiler
|
||||
./wacc-compiler "$@"
|
||||
|
||||
exit $?
|
||||
26
project.scala
Normal file
26
project.scala
Normal file
@ -0,0 +1,26 @@
|
||||
//> using scala 3.6
|
||||
//> using platform jvm
|
||||
|
||||
// dependencies
|
||||
//> using dep com.github.j-mie6::parsley::5.0.0-M12
|
||||
//> using dep com.github.scopt::scopt::4.1.0
|
||||
//> using dep com.lihaoyi::os-lib::0.11.4
|
||||
//> using test.dep org.scalatest::scalatest::3.2.19
|
||||
|
||||
// these are all sensible defaults to catch annoying issues
|
||||
//> using options -deprecation -unchecked -feature
|
||||
//> using options -Wimplausible-patterns -Wunused:all
|
||||
//> using options -Yexplicit-nulls -Wsafe-init -Xkind-projector:underscores
|
||||
|
||||
// these will help ensure you have access to the latest parsley releases
|
||||
// even before they land on maven proper, or snapshot versions, if necessary.
|
||||
// just in case they cause problems, however, keep them turned off unless you
|
||||
// specifically need them.
|
||||
// using repositories sonatype-s01:releases
|
||||
// using repositories sonatype-s01:snapshots
|
||||
|
||||
// these are flags used by Scala native: if you aren't using scala-native, then they do nothing
|
||||
// lto-thin has decent linking times, and release-fast does not too much optimisation.
|
||||
// using nativeLto thin
|
||||
// using nativeGc commix
|
||||
// using nativeMode release-fast
|
||||
109
src/main/wacc/Main.scala
Normal file
109
src/main/wacc/Main.scala
Normal file
@ -0,0 +1,109 @@
|
||||
package wacc.frontend
|
||||
|
||||
import parsley._
|
||||
import scopt.OParser
|
||||
import wacc.backend.CodeGenerator
|
||||
import wacc.frontend.semantic._
|
||||
import wacc.frontend.semantic.environment._
|
||||
import wacc.frontend.syntax.ast.WProgram
|
||||
import wacc.frontend.syntax.ImportHandler
|
||||
import wacc.frontend.syntax.parser
|
||||
import wacc.frontend.syntax.SYNTAX_ERROR_CODE
|
||||
import wacc.extension.Peephole
|
||||
|
||||
val builder = OParser.builder[Config]
|
||||
val oParser = {
|
||||
import builder._
|
||||
OParser.sequence(
|
||||
opt[Unit]('p', "peephole")
|
||||
.action((_, c) => c.copy(peephole = true))
|
||||
.text("Enable peephole mode"),
|
||||
|
||||
opt[Unit]('i', "imports")
|
||||
.action((_, c) => c.copy(imports = true))
|
||||
.text("Enable imports mode"),
|
||||
|
||||
opt[String]('o', "output")
|
||||
.action((x, c) => c.copy(output = Some(x)))
|
||||
.text("Output file")
|
||||
)
|
||||
}
|
||||
|
||||
// This is the configuration for ar
|
||||
case class Config(peephole: Boolean = false,
|
||||
imports: Boolean = false,
|
||||
output: Option[String] = None,
|
||||
input: Option[String] = None)
|
||||
|
||||
// Example script: scala . -- "src/test/wacc/waccPrograms/valid/advanced/hashTable.wacc"
|
||||
def main(args: Array[String]): Unit = {
|
||||
|
||||
// First handle if the input file is present
|
||||
val inputFile = args.headOption match {
|
||||
case Some(file) => file
|
||||
case _ => "Not Found"
|
||||
}
|
||||
|
||||
// We handle the remaining flags
|
||||
val (peepholeFlag, importsFlag, outputFlag) = if (args.length < 2) {
|
||||
(false, false, None)
|
||||
} else {
|
||||
OParser.parse(oParser, args.tail, Config()) match {
|
||||
case Some(config) =>
|
||||
(config.peephole, config.imports, config.output)
|
||||
case _ =>
|
||||
(false, false, None)
|
||||
}
|
||||
}
|
||||
println("Hello WACC")
|
||||
|
||||
if (inputFile == "Not Found") {
|
||||
println("Please enter a valid file path")
|
||||
} else {
|
||||
// Compile the code
|
||||
compile(inputFile, peepholeFlag, importsFlag, outputFlag)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Option to change out how you want to run the Main.scala file
|
||||
* This won't print out the full program to terminal but will open and close the file safely
|
||||
* THIS IS THE FUNCTION TO CALL FOR LABTS
|
||||
*
|
||||
* @param args
|
||||
*/
|
||||
def compile(pathStr: String, peepholeFlag: Boolean, importsFlag: Boolean, outputFlag: Option[String]): Unit = {
|
||||
parser.parseFile(pathStr) match {
|
||||
case parsley.Failure(msg) => {
|
||||
println(s"\nSyntax Error $msg")
|
||||
sys.exit(SYNTAX_ERROR_CODE)
|
||||
}
|
||||
case parsley.Success(progFromSyntax: WProgram) => {
|
||||
val fileName = os.FilePath(pathStr).last
|
||||
val progToAnalyse = ImportHandler.apply().addImports(progFromSyntax, optimiseFlag = importsFlag)
|
||||
analyse(progToAnalyse) match {
|
||||
case Left(res) => {
|
||||
println("Completed Frontend ... Transitioning to backend...")
|
||||
val (prog, fEnv, mEnv) = res
|
||||
val codeGenerator = new CodeGenerator()
|
||||
val codeGen = codeGenerator.generate(prog, fEnv, mEnv)
|
||||
val optimisation = peepholeFlag match {
|
||||
case true => Peephole.apply().optimize(codeGen)
|
||||
case false => codeGen
|
||||
}
|
||||
val asmCode = optimisation.map(_.emit).mkString("\n")
|
||||
val outputFile = s"${fileName.replace(".wacc", ".s")}"
|
||||
val outputPath = os.pwd / os.SubPath(outputFlag.getOrElse("")) / outputFile
|
||||
os.makeDir.all(outputPath / os.up)
|
||||
os.write.over(outputPath, asmCode)
|
||||
println(s"Assembled code generated at ${outputPath / os.up}")
|
||||
}
|
||||
|
||||
case Right(errs) => {
|
||||
errs.foreach(err => println(s"\n${err.getMessage(fileName, pathStr)}"))
|
||||
sys.exit(SEMANTIC_ERROR_CODE)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
754
src/main/wacc/backend/codeGen.scala
Normal file
754
src/main/wacc/backend/codeGen.scala
Normal file
@ -0,0 +1,754 @@
|
||||
package wacc.backend
|
||||
|
||||
import wacc.frontend.syntax.ast._
|
||||
import wacc.frontend.semantic.environment._
|
||||
import scala.collection.mutable
|
||||
|
||||
class CodeGenerator {
|
||||
private given allocator: RegisterAllocator = new RegisterAllocator()
|
||||
private given labelGenerator: LabelGenerator = new LabelGenerator()
|
||||
|
||||
/**
|
||||
* Main entry for code generation
|
||||
* Returns the instructions as a list of instructions
|
||||
*/
|
||||
def generate(program: WProgram, globalFuncEnv: GlobalFuncEnv, globalMainEnv: GlobalMainEnv):
|
||||
List[Instruction] = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
given gMainEnv: GlobalMainEnv = globalMainEnv
|
||||
given gFuncEnv: GlobalFuncEnv = globalFuncEnv
|
||||
|
||||
// Generate code for each function declaration
|
||||
program.funcs.foreach(func => instructions ++= genFunc(func))
|
||||
|
||||
// Prepend the main body of the program to current instructions
|
||||
genMain(program.stats) ++=: instructions
|
||||
|
||||
// Prepend the preamble that has all of the string metadata
|
||||
preamble() ++=: instructions
|
||||
|
||||
// Append any standard labels found during generation of program
|
||||
instructions ++= labelGenerator.genWidgets()
|
||||
|
||||
// Add a new line to signify the end of the file
|
||||
instructions += EOF()
|
||||
|
||||
// Return the final list
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates the preamble for the assembly file
|
||||
* Returns a list of the instructions
|
||||
*/
|
||||
def preamble(): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// Add the .data directive
|
||||
instructions += Directive("data")
|
||||
|
||||
// Add all of the string labels with their info
|
||||
for ((str, (newStr, label)) <- labelGenerator.getStringLiterals()) {
|
||||
instructions ++= List(
|
||||
Comment(s"length of $label", ""),
|
||||
Directive(s"word ${str.size}", " "),
|
||||
Label(label),
|
||||
Directive(s"asciz \"$newStr\"", " ")
|
||||
)
|
||||
}
|
||||
|
||||
// Add remaining preamble info
|
||||
instructions ++= List(
|
||||
Directive("align 4"),
|
||||
Directive("text"),
|
||||
Directive("global main")
|
||||
)
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds the function label
|
||||
* Save frame pointer and link register
|
||||
* Returns a list of the instructions
|
||||
*/
|
||||
def funcPrologue(funcLabel: String, _pushCalleeList: List[(String, PhysicalReg)]): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
instructions ++= List(
|
||||
// Function label
|
||||
Label(funcLabel),
|
||||
|
||||
// Stack setup (pushing frame pointer onto link register)
|
||||
Comment("push {fp, lr}"),
|
||||
STP(FP, LR, SP),
|
||||
|
||||
// Set the frame pointer to the current stack pointer now that fp has been saved
|
||||
Comment("Set fp to sp"),
|
||||
MoveOps(FP, SP)
|
||||
)
|
||||
|
||||
// Push all of the assigned registers in pushCalleeList using STP, where any odd number of registers
|
||||
// Would yield a storing of it and the zero register
|
||||
instructions ++= generatePushInstructions(_pushCalleeList, callingFunc = false)
|
||||
|
||||
// Compute and reserve space for the remaining local variables (Rounded to 16 Bytes to ensure stack alignment)
|
||||
val _remainingOffset = ((allocator.getNumVariables() - MAX_CALLEE_REGS + 1) & ~1) * EIGHT_BYTES
|
||||
if (_remainingOffset > 0) {
|
||||
val remainingOffset = Immediate(_remainingOffset)
|
||||
instructions += SUB(SP, SP, remainingOffset)
|
||||
}
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
/**
|
||||
* Function epilogue: pops all registers that were pushed frame pointer and return
|
||||
* Returns a list of the instructions
|
||||
*/
|
||||
def funcEpilogue(_popCalleeList: List[(String, PhysicalReg)]): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// Add back the displacement made to sp if the offset is > 0
|
||||
val _remainingOffset = ((allocator.getNumVariables() - MAX_CALLEE_REGS + 1) & ~1) * EIGHT_BYTES
|
||||
if (_remainingOffset > 0) {
|
||||
val remainingOffset = Immediate(_remainingOffset)
|
||||
instructions += ADD(SP, SP, remainingOffset)
|
||||
}
|
||||
|
||||
// Add the comment for the registers being freed
|
||||
instructions ++= generatePopInstructions(_popCalleeList.toList, callingFunc = false)
|
||||
|
||||
// Pop the the frame pointer and link register
|
||||
instructions ++= List(
|
||||
Comment("pop {fp, lr}"),
|
||||
LDP(FP, LR, SP),
|
||||
RET(),
|
||||
EOF()
|
||||
)
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
/*
|
||||
Generates the push instructions for the registers in the pushList
|
||||
Returns a list of the instructions
|
||||
*/
|
||||
def generatePushInstructions(_pushList: List[(String, Allocated)], callingFunc: Boolean):
|
||||
List[Instruction] = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// Tranform the _pushList to only contain x registers
|
||||
val pushList = _pushList.collect {
|
||||
case (id, reg: Reg) => (id, Register(reg.regNumber, X_REG))
|
||||
}
|
||||
|
||||
if (!pushList.isEmpty) {
|
||||
instructions += Comment(s"push {${_pushList.map(_._2.toOperand).mkString(", ")}}")
|
||||
}
|
||||
pushList.sliding(2, 2).foreach { pair =>
|
||||
// If there is a pair of registers, then push them together, otherwise, push the singleton
|
||||
if (pair.length == 2) {
|
||||
instructions += STP(pair(0)._2, pair(1)._2, SP)
|
||||
} else {
|
||||
instructions += STP(pair(0)._2, XZR, SP)
|
||||
}
|
||||
}
|
||||
|
||||
if (!pushList.isEmpty && callingFunc) {
|
||||
// Here we add the stack allocations in a fashion similar to stp for registers so that we can
|
||||
// change stack allocation easier later on
|
||||
val spillPushList = _pushList.collect { case (id, spill: Spill) => (id, spill)}
|
||||
|
||||
// Calculate the amount of stack space to allocate to the stack
|
||||
val _addedStackSpace = ((spillPushList.size + 1) & ~1) * EIGHT_BYTES
|
||||
val addedStackSpace = Immediate(_addedStackSpace)
|
||||
|
||||
if (_addedStackSpace != 0) {
|
||||
// Subtract that space from the stack pointer
|
||||
instructions += SUB(SP, SP, addedStackSpace)
|
||||
}
|
||||
|
||||
// Push each memory location onto the stack
|
||||
spillPushList.sliding(2, 2).zipWithIndex.foreach { case (pair, index) =>
|
||||
|
||||
// Index is incremented by 1
|
||||
val adjustedIndex = index + 1
|
||||
|
||||
// If there is a pair of spills, then push the rhs first and then the lhs
|
||||
if (pair.length == 2) {
|
||||
|
||||
// Store the two allocated values
|
||||
instructions ++= pushMemToStack(pair(1)._2, 2 * adjustedIndex - 1, _addedStackSpace)
|
||||
instructions ++= pushMemToStack(pair(0)._2, 2 * adjustedIndex, _addedStackSpace)
|
||||
|
||||
// Otherwise, push the zero register and then the other register
|
||||
} else {
|
||||
|
||||
// The offset to save memory to
|
||||
val saveOffset = Immediate(_addedStackSpace - (2 * adjustedIndex - 1) * EIGHT_BYTES)
|
||||
|
||||
// Store the two allocated values
|
||||
instructions += Store(XZR, SP, Some(saveOffset))
|
||||
instructions ++= pushMemToStack(pair(0)._2, (2 * adjustedIndex), _addedStackSpace)
|
||||
}
|
||||
}
|
||||
|
||||
// Order the registers in the same order of how they'd appear in the stack
|
||||
val stackList = (toStackList(_pushList)).reverse
|
||||
|
||||
instructions += Comment(s"Our stack has the order of (front of list shows what #0 is): [${stackList.map(_._2.toOperand).mkString(", ")}]")
|
||||
|
||||
// We change the identifier map to point to the stack position of the stored values if we are calling a function
|
||||
stackList.zipWithIndex.foreach { case ((id, alloc), index) =>
|
||||
|
||||
// We create an offset equal to the index of the register times 8 Bytes
|
||||
val offsetImm = Immediate(EIGHT_BYTES * index)
|
||||
val stackLocation = Spill(offsetImm, X16)
|
||||
|
||||
allocator.addAllocMapping(id, stackLocation)
|
||||
}
|
||||
instructions += Comment("Set up X16 as a temporary second base pointer for the caller saved things")
|
||||
instructions += MoveOps(X16, SP)
|
||||
}
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
/*
|
||||
Generates the pop instructions for the registers in the popList
|
||||
Returns a list of the instructions
|
||||
*/
|
||||
|
||||
def generatePopInstructions(_popList: List[(String, Allocated)], callingFunc: Boolean):
|
||||
List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// Tranform the _popList to a list only containing x registers
|
||||
val popList = _popList.collect{case (id, reg: Reg) => (id, Register(reg.regNumber, X_REG))}
|
||||
|
||||
// First remove the memory that was stored away if we were calling a function
|
||||
if (!popList.isEmpty && callingFunc) {
|
||||
|
||||
val spillPushList = _popList.collect { case (id, spill: Spill) => (id, spill)}
|
||||
|
||||
// Calculate the amount of stack space that was taken away
|
||||
val _addedStackSpace = ((spillPushList.size + 1) & ~1) * EIGHT_BYTES
|
||||
val addedStackSpace = Immediate(_addedStackSpace)
|
||||
|
||||
if (_addedStackSpace != 0) {
|
||||
// Add that space back to the stack pointer
|
||||
instructions += ADD(SP, SP, addedStackSpace)
|
||||
}
|
||||
}
|
||||
|
||||
if (popList.nonEmpty) {
|
||||
instructions += Comment(s"pop {${_popList.reverse.map(_._2.toOperand).mkString(", ")}}")
|
||||
|
||||
val rest = if (popList.length % 2 == 1) {
|
||||
instructions += LDP(popList.head._2, XZR, SP) // Pop first register if odd
|
||||
popList.tail // Process the remaining elements
|
||||
} else popList
|
||||
|
||||
rest.sliding(2, 2).foreach { pair =>
|
||||
|
||||
// As they were pushed in order, to maintain the order, pop in reverse taking right then left
|
||||
instructions += LDP(pair(1)._2, pair(0)._2, SP)
|
||||
}
|
||||
}
|
||||
|
||||
popList.foreach(pair => allocator.freeCallee(pair._2))
|
||||
|
||||
// We revert the identifier map to point back to the original registers if we called a function
|
||||
// Note: we use _popList instead of popList as that contains the correct mapping for all allocated values
|
||||
if (callingFunc) {
|
||||
_popList.foreach { case (id, alloc) =>
|
||||
allocator.addAllocMapping(id, alloc)
|
||||
}
|
||||
}
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
def toStackList(list: List[(String, Allocated)]): List[(String, Allocated)] = {
|
||||
// Order the registers in the same order of how they'd appear in the stack (but reversed)
|
||||
list.sliding(2, 2).flatMap { pair =>
|
||||
if (pair.length == 2) {
|
||||
List(pair(1), pair(0))
|
||||
} else {
|
||||
List(("", XZR), pair(0))
|
||||
}
|
||||
}.toList
|
||||
}
|
||||
|
||||
def pushMemToStack(spill: Spill, num: Int, totalStackSpace: BigInt): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// The offset to save memory to
|
||||
val saveOffset = Immediate(totalStackSpace - num * EIGHT_BYTES)
|
||||
|
||||
// We must load the contents of the location to x17 and then store that into the stack pointer at the right place
|
||||
instructions += Load(X17, spill)
|
||||
instructions += Store(X17, SP, Some(saveOffset))
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates all of the instructions for the main body block
|
||||
* Returns a list of the instructions
|
||||
*/
|
||||
def genMain(stmts: List[Stmt])
|
||||
(using gMainEnv: GlobalMainEnv, gFuncEnv: GlobalFuncEnv): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// Generate code for each statement in the main block
|
||||
given pushCalleeList: mutable.ListBuffer[(String, PhysicalReg)] = mutable.ListBuffer.empty
|
||||
stmts.foreach(stmt => instructions ++= genStmt(stmt)(using pushCalleeList, gMainEnv, gFuncEnv, List.empty))
|
||||
|
||||
// Prepend the prologue with the pushCalleeList
|
||||
funcPrologue("main", pushCalleeList.toList) ++=: instructions
|
||||
|
||||
// Store 0 in return reg as the default exit code if program is fine
|
||||
instructions += MoveOps(X0, Imm0)
|
||||
|
||||
// Append the epilogue with the popCalleeList(pushCalleeList reversed)
|
||||
instructions ++= funcEpilogue(pushCalleeList.reverse.toList)
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates all of the instructions for the functions
|
||||
* Returns a list of the instructions
|
||||
*/
|
||||
def genFunc(func: Func)(using gMainEnv: GlobalMainEnv, gFuncEnv: GlobalFuncEnv):
|
||||
List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// The id of the function used to query the global function environment
|
||||
val funcKey = func.id.id
|
||||
|
||||
// Replace # with an underscore for funcLabel
|
||||
val funcLabel = func.id.id.replaceAll("#", "_")
|
||||
|
||||
// Every function has a different push list to see stored variables
|
||||
given pushCalleeList: mutable.ListBuffer[(String, PhysicalReg)] = mutable.ListBuffer.empty
|
||||
|
||||
// Save the state of the allocator before setting up function
|
||||
val prevState = allocator.saveState()
|
||||
|
||||
// Retrieve the parameters of the function
|
||||
val funcParams = gFuncEnv.lookup(funcKey).getOrElse(
|
||||
throw new RuntimeException(s"Function $funcKey not found in the environment")
|
||||
).params
|
||||
|
||||
given funcRegList: mutable.ListBuffer[(String, Allocated)] = mutable.ListBuffer.empty
|
||||
funcParams.zipWithIndex.foreach{case (param, index) =>
|
||||
// Add the parameter to the global environment
|
||||
gMainEnv.add(param)
|
||||
|
||||
// Add the parameter to the function register list
|
||||
val bitMode = param._type match {
|
||||
case (IntType | CharType | BoolType) => W_REG
|
||||
case _ => X_REG
|
||||
}
|
||||
|
||||
val paramAlloc = index match {
|
||||
|
||||
// If the index is between 0 and 7, then create a parameter register for it
|
||||
case paramReg if paramReg < MAX_PARAM_REGS => Register(s"$paramReg", bitMode)
|
||||
|
||||
// Otherwise, create a space in memory for it
|
||||
case nonReg =>
|
||||
|
||||
// Calculate the stack offset for the non register parameter
|
||||
val offset = (nonReg - MAX_PARAM_REGS) * EIGHT_BYTES
|
||||
|
||||
// Adjust the offset by 16 Bytes to account for fp lr push onto stack
|
||||
val adjustedOffset = BigInt(offset) + SIXTEEN_BYTES
|
||||
val stackOffset = Immediate(adjustedOffset)
|
||||
|
||||
// Create a stack allocation with FP as the base pointer as this is how to access the
|
||||
// variable in the function
|
||||
Spill(stackOffset, FP)
|
||||
}
|
||||
|
||||
// Adding the param's name with its arg register to the funcRegList
|
||||
val listEntryFunc = (param.uniqueName, paramAlloc)
|
||||
funcRegList += listEntryFunc
|
||||
|
||||
// Also add the allocation to the callee push list
|
||||
val listEntryCallee = (param.uniqueName, paramAlloc)
|
||||
listEntryCallee match {
|
||||
case (identifier, physicalReg: PhysicalReg) =>
|
||||
val entry = (identifier, physicalReg)
|
||||
pushCalleeList += entry
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
// Create the linkage between an id and its paramAlloc
|
||||
allocator.addAllocMapping(param.uniqueName, paramAlloc)
|
||||
}
|
||||
|
||||
// Generate code for each statement in the function block
|
||||
func.body.foreach(stmt => instructions ++=
|
||||
genStmt(stmt)(using pushCalleeList, gMainEnv, gFuncEnv, funcRegList.toList))
|
||||
|
||||
// Prepend the prologue with the pushCalleeList
|
||||
funcPrologue(funcLabel, pushCalleeList.toList) ++=: instructions
|
||||
|
||||
// Replace all instances of a comment called "EPILOGUE" with the epilogue of the function
|
||||
// At this point, we have all callee registers known to us
|
||||
val finalInstructions = instructions.flatMap {
|
||||
case Comment("EPILOGUE", "") => funcEpilogue(pushCalleeList.reverse.toList)
|
||||
case other => List(other)
|
||||
}
|
||||
|
||||
// Restore the state to allow for callees to run as intended
|
||||
allocator.restoreState(prevState)
|
||||
|
||||
finalInstructions.toList
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates code for a statement
|
||||
* Returns a list of instructions
|
||||
*/
|
||||
def genStmt(stmt: Stmt)
|
||||
(using pushCalleeList: mutable.ListBuffer[(String, PhysicalReg)], gMainEnv: GlobalMainEnv,
|
||||
gFuncEnv: GlobalFuncEnv, funcRegList: List[(String, Allocated)]): List[Instruction] = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
stmt match {
|
||||
case Skip(_) =>
|
||||
|
||||
// We don't have any instructions here
|
||||
instructions += Comment("SKIP")
|
||||
instructions.toList
|
||||
|
||||
case Declare(_, t, ident, rhs) =>
|
||||
instructions += Comment("DECLARE BEGINS HERE")
|
||||
|
||||
// Evaluate the right-hand side
|
||||
val (rhsInstr, rhsAlloc) = genRhs(rhs, this)
|
||||
instructions ++= rhsInstr
|
||||
|
||||
// Allocate a callee allocation for this ident
|
||||
val calleeAlloc = rhsType(rhs)(using gMainEnv, gFuncEnv) match {
|
||||
|
||||
// Ints, Chars, and Bools all use W_REG register bit-modes
|
||||
case (IntType | CharType | BoolType) => allocator.allocateCallee(W_REG)
|
||||
case _ => allocator.allocateCallee(X_REG)
|
||||
}
|
||||
|
||||
// Add the allocation to pushCalleeList (Will create STP instructions to store to the stack)
|
||||
val listEntryCallee = (ident.id, calleeAlloc)
|
||||
listEntryCallee match {
|
||||
case (identity, physicalReg: PhysicalReg) =>
|
||||
val entry = (identity, physicalReg)
|
||||
pushCalleeList += entry
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
// Store the value of the rhs to the lhs
|
||||
// Check on the whether the operand is on the stack or a register
|
||||
instructions += Move(calleeAlloc, rhsAlloc)
|
||||
|
||||
// Add the allocation to the variable's map of vars to maps
|
||||
allocator.addAllocMapping(ident.id, calleeAlloc)
|
||||
|
||||
// Free the temp register
|
||||
allocator.freeTemp(rhsAlloc)
|
||||
|
||||
rhs match {
|
||||
case Call(_, id, args) =>
|
||||
|
||||
// If the rhs is a function call, free all of the callee registers
|
||||
instructions ++= generatePopInstructions(funcRegList.reverse, callingFunc = true)
|
||||
case _ =>
|
||||
}
|
||||
|
||||
instructions.toList
|
||||
|
||||
case Assign(_, lhs, rhs) =>
|
||||
instructions += Comment("ASSIGN BEGINS HERE")
|
||||
|
||||
// Create w/x8 register using allocTemp as this is at the front of the temp pool (only for ArrayElems)
|
||||
val bufferReg8 = lhsType(lhs)(using gMainEnv) match {
|
||||
case (IntType | CharType | BoolType) => allocator.allocateTemp(W_REG)
|
||||
case _ => allocator.allocateTemp(X_REG)
|
||||
}
|
||||
|
||||
// Only free the lhs if it's not an arrayElem
|
||||
lhs match {
|
||||
case ArrayElem(_, _, _) => ()
|
||||
case _ => allocator.freeTemp(bufferReg8)
|
||||
}
|
||||
|
||||
val (rhsInstr, rhsAlloc) = genRhs(rhs, this)
|
||||
instructions ++= rhsInstr
|
||||
|
||||
val newRhsAlloc = rhs match {
|
||||
case Call(_, id, args) =>
|
||||
|
||||
// Set up the function scrath and return registers according to the lhs type
|
||||
val (functionScratchReg, returnReg) = lhsType(lhs)(using gMainEnv) match {
|
||||
case (IntType | CharType | BoolType) => (W16, W0)
|
||||
case _ => (X16, X0)
|
||||
}
|
||||
|
||||
// Move the output to a function scratch register
|
||||
instructions += MOV(functionScratchReg, returnReg)
|
||||
|
||||
// Free all of the callee registers
|
||||
instructions ++= generatePopInstructions(funcRegList.reverse, callingFunc = true)
|
||||
|
||||
functionScratchReg
|
||||
|
||||
case _ => rhsAlloc
|
||||
}
|
||||
|
||||
// Get the lhs to move the rhs into
|
||||
val (lhsInstr, lhsAlloc) = genLhsForUsing(lhs, instructions)
|
||||
instructions ++= lhsInstr
|
||||
|
||||
// Make the lhsMove now
|
||||
// Performs the moving of the rhs to the lhs which handles edge cases (ArrayElem, PairElem)
|
||||
instructions ++= lhsMove(lhs, lhsAlloc, newRhsAlloc, bufferReg8)
|
||||
|
||||
// Free all temp allocations
|
||||
allocator.freeTemp(newRhsAlloc)
|
||||
allocator.freeTemp(rhsAlloc)
|
||||
allocator.freeTemp(lhsAlloc)
|
||||
allocator.freeTemp(bufferReg8)
|
||||
|
||||
instructions.toList
|
||||
|
||||
case pStmt: PrintBase =>
|
||||
instructions += Comment("PRINT BEGINS HERE")
|
||||
|
||||
// Push registers onto the stack to save their value
|
||||
instructions ++= generatePushInstructions(funcRegList, true)
|
||||
|
||||
val (exprInstr, exprAlloc) = genExpr(pStmt.expr)
|
||||
|
||||
// The bit mode of the return register (x/w 0) will be determined by the expression's type
|
||||
val returnReg = exprType(pStmt.expr)(using gMainEnv) match {
|
||||
case (IntType | CharType | BoolType) => W0
|
||||
case _ => X0
|
||||
}
|
||||
instructions ++= exprInstr
|
||||
instructions += MoveOps(returnReg, exprAlloc)
|
||||
|
||||
// Check for the type of the expression here
|
||||
exprType(pStmt.expr)(using gMainEnv) match {
|
||||
case BoolType =>
|
||||
instructions += BL(_printb)
|
||||
case CharType =>
|
||||
instructions += BL(_printc)
|
||||
case IntType =>
|
||||
instructions += BL(_printi)
|
||||
|
||||
// Can print Char arrays as a string
|
||||
case (StrType | ArrayType(CharType)) =>
|
||||
instructions += BL(_prints)
|
||||
|
||||
// All of the other types require _printp
|
||||
case _ =>
|
||||
instructions += BL(_printp)
|
||||
}
|
||||
|
||||
// Add the branch to _println for a Println statement
|
||||
if (pStmt.isInstanceOf[Println]) {
|
||||
instructions += BL(_println)
|
||||
}
|
||||
|
||||
// Free allocation
|
||||
allocator.freeTemp(exprAlloc)
|
||||
|
||||
// Pop off the saved registers from the stack
|
||||
instructions ++= generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
|
||||
instructions.toList
|
||||
|
||||
case Read(_, lhs) =>
|
||||
instructions += Comment("READ BEGINS HERE")
|
||||
|
||||
// Save argument registers
|
||||
instructions ++= generatePushInstructions(funcRegList, callingFunc = true)
|
||||
|
||||
// Create w8 register using allocTemp as this is at the front of the temp pool
|
||||
val bufferReg8 = allocator.allocateTemp(W_REG)
|
||||
val (lhsInstr, lhsAlloc) = genLhs(lhs)
|
||||
|
||||
// Ints and Chars will always use W_REG bit-mode registers
|
||||
instructions ++= lhsInstr
|
||||
lhs match {
|
||||
case _: PairElem =>
|
||||
instructions ++= List(
|
||||
Compare(lhsAlloc, Imm0),
|
||||
BCond(EQ, _errNull),
|
||||
Load(X0, lhsAlloc)
|
||||
)
|
||||
case _ => instructions += Move(W0, lhsAlloc)
|
||||
}
|
||||
|
||||
// Check for the type of the lhs here
|
||||
lhsType(lhs)(using gMainEnv) match {
|
||||
case CharType =>
|
||||
instructions += BL(_readc)
|
||||
|
||||
case IntType =>
|
||||
instructions += BL(_readi)
|
||||
|
||||
case _ =>
|
||||
instructions += Comment("UNKNOWN TYPE IN READ")
|
||||
}
|
||||
|
||||
// Finish register manipulation
|
||||
// Depending on the type of the lhs, invoke the correct next move
|
||||
// Ints and Chars will always use W_REG bit-mode registers
|
||||
instructions += MoveOps(W16, W0)
|
||||
|
||||
// Free the lhsAlloc before reloading lhs
|
||||
allocator.freeTemp(lhsAlloc)
|
||||
|
||||
// Reload the lhs so that we get the correct address
|
||||
val (reloadInstrs, reloadLhs) = genLhsForUsing(lhs, instructions)
|
||||
instructions ++= reloadInstrs
|
||||
instructions ++= lhsMove(lhs, reloadLhs, W16, bufferReg8)
|
||||
|
||||
allocator.freeTemp(bufferReg8)
|
||||
allocator.freeTemp(reloadLhs)
|
||||
|
||||
// Pop off argument registers
|
||||
instructions ++= generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
|
||||
instructions.toList
|
||||
|
||||
case Exit(_, expr) =>
|
||||
instructions += Comment("EXIT BEGINS HERE")
|
||||
|
||||
// Push registers onto the stack to save their value
|
||||
instructions ++= generatePushInstructions(funcRegList, true)
|
||||
|
||||
// For an exit statement, evaluate the expression and call the exit routine
|
||||
val (exprInstr, exprAlloc) = genExpr(expr)
|
||||
instructions ++= exprInstr
|
||||
|
||||
// The expression is always an int so we use the W_REG bit-mode
|
||||
instructions += MoveOps(W0, exprAlloc)
|
||||
instructions += BL(exit)
|
||||
|
||||
// Pop off argument registers
|
||||
instructions ++= generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
|
||||
|
||||
allocator.freeTemp(exprAlloc)
|
||||
instructions.toList
|
||||
|
||||
case Return(_, expr) =>
|
||||
instructions += Comment("RETURN BEGINS HERE")
|
||||
val (exprInstr, exprAlloc) = genExpr(expr)
|
||||
|
||||
// The bit mode of the return register (x/w 0) will be determined by the expression's type
|
||||
val returnReg = exprType(expr)(using gMainEnv) match {
|
||||
case (IntType | CharType | BoolType) => W0
|
||||
case _ => X0
|
||||
}
|
||||
instructions ++= exprInstr
|
||||
instructions += MoveOps(returnReg, exprAlloc)
|
||||
|
||||
// We mark this section as needing to have an epilogue
|
||||
// It is filled in later in the genFunc method
|
||||
instructions += Comment("EPILOGUE", "")
|
||||
|
||||
allocator.freeTemp(exprAlloc)
|
||||
instructions.toList
|
||||
|
||||
case If(_, cond, thenStats, elseStats) =>
|
||||
instructions += Comment("IF BEGINS HERE")
|
||||
|
||||
// Generate fresh labels for then and end case. Else case would go right after label
|
||||
val thenLabel = labelGenerator.freshLabel()
|
||||
val endLabel = labelGenerator.freshLabel()
|
||||
|
||||
// Generate code for the condition, ending with a Compare instruction
|
||||
val (condInstr, branchFlag) = genCond(cond)
|
||||
instructions ++= condInstr
|
||||
|
||||
// Branch to the then block if the condition is true
|
||||
instructions += BCond(branchFlag, thenLabel)
|
||||
|
||||
// Continue onto generating code for the else block
|
||||
elseStats.foreach(stmt => instructions ++= genStmt(stmt))
|
||||
|
||||
// Branch to the end label after finishing else block
|
||||
instructions += B(endLabel)
|
||||
|
||||
// Label and instructions for then
|
||||
instructions += Label(thenLabel)
|
||||
thenStats.foreach(stmt => instructions ++= genStmt(stmt))
|
||||
|
||||
// End of the if statement
|
||||
instructions += Label(endLabel)
|
||||
|
||||
instructions.toList
|
||||
|
||||
case While(_, cond, body) =>
|
||||
instructions += Comment("WHILE BEGINS HERE")
|
||||
|
||||
// Generate fresh labels for the loop
|
||||
val condLabel = labelGenerator.freshLabel()
|
||||
val bodyLabel = labelGenerator.freshLabel()
|
||||
|
||||
// Branch to the condition check
|
||||
instructions += B(condLabel)
|
||||
|
||||
// Generate code for the loop body
|
||||
instructions += Label(bodyLabel)
|
||||
body.foreach(stmt => instructions ++= genStmt(stmt))
|
||||
|
||||
// After the body, check the condition again
|
||||
instructions += B(condLabel)
|
||||
|
||||
// Generate code for the condition check
|
||||
instructions += Label(condLabel)
|
||||
val (condInstr, condFlag) = genCond(cond)
|
||||
instructions ++= condInstr
|
||||
instructions += BCond(condFlag, bodyLabel)
|
||||
|
||||
// End of the loop just continues
|
||||
|
||||
instructions.toList
|
||||
|
||||
// Generates each statement in the block
|
||||
case Block(_, stmts) => stmts.flatMap(genStmt)
|
||||
|
||||
case Free(_, expr: Rhs) =>
|
||||
val (exprInstr, exprAlloc) = genRhs(expr, this)
|
||||
|
||||
// Save argument registers
|
||||
instructions ++= generatePushInstructions(funcRegList, callingFunc = true)
|
||||
instructions ++= exprInstr
|
||||
|
||||
exprType(expr)(using gMainEnv) match {
|
||||
case _: ArrayType =>
|
||||
instructions ++= List(
|
||||
Comment("array pointers are shifted forward by 4 bytes, " +
|
||||
"so correct it back to original pointer before free"),
|
||||
SUB(exprAlloc, exprAlloc, Imm4),
|
||||
MoveOps(X0, exprAlloc),
|
||||
BL(free)
|
||||
)
|
||||
case _: (PairType | PairPlaceholder.type | Any) =>
|
||||
instructions += MoveOps(X0, exprAlloc)
|
||||
instructions += BL(_freepair)
|
||||
}
|
||||
|
||||
// Pop off argument registers
|
||||
instructions ++= generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
|
||||
|
||||
allocator.freeTemp(exprAlloc)
|
||||
instructions.toList
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
449
src/main/wacc/backend/expressionGen.scala
Normal file
449
src/main/wacc/backend/expressionGen.scala
Normal file
@ -0,0 +1,449 @@
|
||||
package wacc.backend
|
||||
|
||||
import wacc.frontend.syntax.ast._
|
||||
import wacc.frontend.semantic.environment._
|
||||
import scala.collection.mutable
|
||||
|
||||
/**
|
||||
* Generates code for an expression
|
||||
* Returns a tuple: (instructions, allocated operand holding the result)
|
||||
*/
|
||||
def genExpr(expr: Expr)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
(List[Instruction], PhysicalReg) = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// The bit-mode of the temp register depends on the type of the expression
|
||||
val tempAlloc = exprType(expr) match {
|
||||
case (IntType | CharType | BoolType) => allocator.allocateTemp(W_REG)
|
||||
case _ => allocator.allocateTemp(X_REG)
|
||||
}
|
||||
|
||||
expr match {
|
||||
case IntLiter(_, value) =>
|
||||
val valueImm = Immediate(value)
|
||||
instructions += MoveOps(tempAlloc, valueImm)
|
||||
(instructions.toList, tempAlloc)
|
||||
|
||||
case BoolLiter(_, _value) =>
|
||||
val value = if (_value) 1 else 0
|
||||
val valueImm = Immediate(value)
|
||||
instructions += MoveOps(tempAlloc, valueImm)
|
||||
(instructions.toList, tempAlloc)
|
||||
|
||||
case CharLiter(_, value) =>
|
||||
val valueImm = Immediate(value.toInt)
|
||||
instructions += MoveOps(tempAlloc, valueImm)
|
||||
(instructions.toList, tempAlloc)
|
||||
|
||||
case StrLiter(_, value) =>
|
||||
val strLabel = labelGenerator.freshLabel(value)
|
||||
instructions += ADRP(tempAlloc, strLabel)
|
||||
instructions += ADD(tempAlloc, tempAlloc, LabelOp(strLabel))
|
||||
(instructions.toList, tempAlloc)
|
||||
|
||||
case PairLiter(_) =>
|
||||
instructions += MoveOps(tempAlloc, Imm0)
|
||||
(instructions.toList, tempAlloc)
|
||||
|
||||
case Ident(_, id) =>
|
||||
val _idAlloc = allocator.getAllocMapping(id)
|
||||
|
||||
// Make the bit-mode of the idAlloc the same as the bitmode of the tempAlloc
|
||||
val idAlloc = _idAlloc match {
|
||||
case reg: PhysicalReg => tempAlloc match {
|
||||
case PhysicalReg(_, X_REG, _) => PhysicalReg(reg.regNumber, X_REG, reg.pool)
|
||||
case PhysicalReg(_, other, _) => PhysicalReg(reg.regNumber, W_REG, reg.pool)
|
||||
}
|
||||
case spill => spill
|
||||
}
|
||||
instructions += Move(tempAlloc, idAlloc)
|
||||
(instructions.toList, tempAlloc)
|
||||
|
||||
case Add(lhs, rhs) =>
|
||||
allocator.freeTemp(tempAlloc)
|
||||
genArithOp(lhs, rhs, ADDS.apply)
|
||||
case Sub(lhs, rhs) =>
|
||||
allocator.freeTemp(tempAlloc)
|
||||
genArithOp(lhs, rhs, SUBS.apply)
|
||||
case Mul(lhs, rhs) =>
|
||||
|
||||
// Restore the register to make efficient use of registers
|
||||
allocator.restore()
|
||||
val (lhsInstr, lhsAlloc) = genExpr(lhs)
|
||||
val (rhsInstr, rhsAlloc) = genExpr(rhs)
|
||||
|
||||
// Need an x bit-mode register for detecting overflow errors
|
||||
val xLhs = PhysicalReg(lhsAlloc.regNumber, X_REG, lhsAlloc.pool)
|
||||
|
||||
instructions ++= lhsInstr
|
||||
instructions ++= rhsInstr
|
||||
instructions ++= List(
|
||||
SMULL(xLhs, lhsAlloc, rhsAlloc),
|
||||
Compare(xLhs, lhsAlloc ,Some(SXTW())),
|
||||
BCond(NE, _errOverflow)
|
||||
)
|
||||
|
||||
allocator.freeTemp(rhsAlloc)
|
||||
(instructions.toList, lhsAlloc)
|
||||
|
||||
// Handles Division and Modulo operations
|
||||
case divOp: DivOp =>
|
||||
val (lhs, rhs) = (divOp.lhs, divOp.rhs)
|
||||
allocator.restore()
|
||||
val (lhsInstr, lhsAlloc) = genExpr(lhs)
|
||||
val (rhsInstr, rhsAlloc) = genExpr(rhs)
|
||||
instructions ++= lhsInstr
|
||||
instructions ++= rhsInstr
|
||||
instructions ++= List(
|
||||
Compare(rhsAlloc, Imm0),
|
||||
BCond(EQ, _errDivZero),
|
||||
)
|
||||
divOp match {
|
||||
case _: Div => instructions += SDIV(lhsAlloc, lhsAlloc, rhsAlloc)
|
||||
case _: Mod =>
|
||||
instructions += SDIV(W17, lhsAlloc, rhsAlloc)
|
||||
instructions += MSUB(lhsAlloc, W17, rhsAlloc, lhsAlloc)
|
||||
}
|
||||
allocator.freeTemp(rhsAlloc)
|
||||
(instructions.toList, lhsAlloc)
|
||||
|
||||
case And(lhs, rhs) => genBoolOp(lhs, rhs, NE)
|
||||
case Or(lhs, rhs) => genBoolOp(lhs, rhs, EQ)
|
||||
case compOp: CompOp => genComparison(compOp.lhs, compOp.rhs, compOp)
|
||||
case chr: Chr =>
|
||||
val (chrInstr, chrAlloc) = genUnaryOp(chr.expr, chr)
|
||||
val xChr = chrAlloc match {
|
||||
case PhysicalReg(regNumber, _ , pool) => PhysicalReg(regNumber, X_REG, pool)
|
||||
}
|
||||
instructions ++= chrInstr
|
||||
val testVal = -128
|
||||
val tstImm = Immediate(testVal)
|
||||
instructions ++= List(
|
||||
TST(chrAlloc, tstImm),
|
||||
CSEL(X1, xChr, X1, NE),
|
||||
BCond(NE, _errBadChar)
|
||||
)
|
||||
(instructions.toList, chrAlloc)
|
||||
|
||||
case neg: Neg =>
|
||||
val (negInstr, negAlloc) = genUnaryOp(neg.expr, neg)
|
||||
instructions ++= negInstr
|
||||
instructions += BCond(VS, _errOverflow)
|
||||
(instructions.toList, negAlloc)
|
||||
|
||||
case unOp: UnOp => genUnaryOp(unOp.expr, unOp)
|
||||
|
||||
case IfExpr(pos, cond, thenBranch, elseBranch) =>
|
||||
val thenLabel = labelGenerator.freshLabel()
|
||||
val endLabel = labelGenerator.freshLabel()
|
||||
val (condInstr, branchFlag) = genCond(cond)
|
||||
|
||||
instructions ++= condInstr
|
||||
|
||||
// Branch to the then block if the condition is true
|
||||
instructions += BCond(branchFlag, thenLabel)
|
||||
val resultReg = exprType(thenBranch) match {
|
||||
case (IntType | CharType | BoolType) => allocator.allocateTemp(W_REG)
|
||||
case _ => allocator.allocateTemp(X_REG)
|
||||
}
|
||||
val (elseInstr, elseAlloc) = genExpr(elseBranch)
|
||||
instructions ++= elseInstr
|
||||
instructions += Move(resultReg, elseAlloc)
|
||||
instructions += B(endLabel)
|
||||
|
||||
instructions += Label(thenLabel)
|
||||
val (thenInstr, thenAlloc) = genExpr(thenBranch)
|
||||
instructions ++= thenInstr
|
||||
instructions += Move(resultReg, thenAlloc)
|
||||
instructions += B(endLabel)
|
||||
|
||||
instructions += Label(endLabel)
|
||||
|
||||
(instructions.toList, resultReg)
|
||||
|
||||
case PairElem(_, selector, lhs: Expr) =>
|
||||
allocator.restore()
|
||||
val (exprInstr, exprAlloc) = genExpr(lhs)
|
||||
val xExpr = PhysicalReg(exprAlloc.regNumber, X_REG, exprAlloc.pool)
|
||||
instructions ++= exprInstr
|
||||
instructions += Compare(exprAlloc, Imm0)
|
||||
instructions += BCond(EQ, _errNull)
|
||||
selector match {
|
||||
case Fst => instructions += LDR(xExpr, exprAlloc)
|
||||
case Snd => instructions += LDR(xExpr, exprAlloc, Some(Imm8))
|
||||
}
|
||||
|
||||
// We use knowledge of the bitmode from tempAlloc to get the right bitmode for return
|
||||
val resultReg = PhysicalReg(xExpr.regNumber, tempAlloc.bitMode, xExpr.pool)
|
||||
(instructions.toList, resultReg)
|
||||
|
||||
case ArrayElem(_, id, indices) =>
|
||||
// Get the array type of the identifier we are accessing
|
||||
var curType = exprType(id) match {
|
||||
case ArrayType(innerType) => innerType
|
||||
case _type => _type
|
||||
}
|
||||
|
||||
val xTemp = PhysicalReg(tempAlloc.regNumber, X_REG, tempAlloc.pool)
|
||||
|
||||
for ((indexExpr, i) <- indices.zipWithIndex) {
|
||||
instructions += Comment(s"curType = $curType, indexExpr = $indexExpr, i = $i")
|
||||
|
||||
// If i != last index, then we have to store our previous value (so must be an x bit-value)
|
||||
if (i != 0) {
|
||||
instructions += Comment(s"push {${xTemp}}")
|
||||
instructions += STP(xTemp, XZR, SP)
|
||||
}
|
||||
|
||||
// Generate code for the index expression
|
||||
val (indexInstr, indexAlloc) = genExpr(indexExpr)
|
||||
instructions ++= indexInstr
|
||||
instructions += MoveOps(W17, indexAlloc)
|
||||
|
||||
allocator.freeTemp(indexAlloc)
|
||||
|
||||
// If we're not on the first iteration, pop x7 back
|
||||
if (i != 0) {
|
||||
instructions += Comment("pop {x7}")
|
||||
instructions += LDP(X7, XZR, SP)
|
||||
} else {
|
||||
val (idInstr, idAlloc) = genExpr(id)
|
||||
instructions ++= idInstr
|
||||
instructions += MoveOps(X7, idAlloc)
|
||||
|
||||
allocator.freeTemp(idAlloc)
|
||||
}
|
||||
|
||||
// Call the appropriate array-load function
|
||||
instructions += BL(_arrLoad(s"${sizeof(curType)}"))
|
||||
|
||||
curType match {
|
||||
case ArrayType(_) => instructions += MoveOps(xTemp, X7)
|
||||
case _ =>
|
||||
val reg7 = tempAlloc match {
|
||||
case reg: Reg => reg.bitMode match {
|
||||
case W_REG => W7
|
||||
case X_REG => X7
|
||||
|
||||
// This case shouldn't happen (Maybe call it a random register)
|
||||
case _ => Register("Should know type of reg7", ' ')
|
||||
}
|
||||
}
|
||||
instructions += MoveOps(tempAlloc, reg7)
|
||||
}
|
||||
|
||||
curType match {
|
||||
case ArrayType(innerType) => curType = innerType
|
||||
case _ => ()
|
||||
}
|
||||
}
|
||||
|
||||
(instructions.toList, tempAlloc)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates assembly for a binary operation
|
||||
*/
|
||||
def genBinaryOp(lhs: Expr, rhs: Expr, op: (Reg, Operand, Operand) => Instruction)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
(List[Instruction], PhysicalReg) = {
|
||||
|
||||
// Allocate a w register as all binary operators operate on types of 4 bytes
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// Restore the register to make efficient use of registers
|
||||
val (lhsAlloc, rhsAlloc) =
|
||||
if (exprWeight(lhs) > exprWeight(rhs)) {
|
||||
getEvaluatedExpr(lhs, rhs, instructions, true)
|
||||
} else {
|
||||
getEvaluatedExpr(rhs, lhs, instructions, false)
|
||||
}
|
||||
instructions += op(lhsAlloc, lhsAlloc, rhsAlloc)
|
||||
|
||||
allocator.freeTemp(rhsAlloc)
|
||||
(instructions.toList, lhsAlloc)
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates an expression by taking the lhs and rhs
|
||||
* Uses the takeLeft parameter to dictate the order that we return the allocations with
|
||||
*/
|
||||
def getEvaluatedExpr(lhs: Expr, rhs: Expr, instructions: mutable.ListBuffer[Instruction], takeLeft: Boolean)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
(PhysicalReg, PhysicalReg) = {
|
||||
|
||||
val (lhsInstr, lhsAlloc) = genExpr(lhs)
|
||||
val (rhsInstr, rhsAlloc) = genExpr(rhs)
|
||||
|
||||
instructions ++= lhsInstr
|
||||
instructions ++= rhsInstr
|
||||
|
||||
if (takeLeft) {
|
||||
(lhsAlloc, rhsAlloc)
|
||||
} else {
|
||||
(rhsAlloc, lhsAlloc)
|
||||
}
|
||||
}
|
||||
|
||||
def exprWeight(expr: Expr): Int = expr match {
|
||||
case binOp: BinOp => Math.max(exprWeight(binOp.lhs), exprWeight(binOp.rhs)) + 1
|
||||
case _ => 0
|
||||
}
|
||||
|
||||
def genArithOp(lhs: Expr, rhs: Expr, op: (Reg, Operand, Operand) => Instruction)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
(List[Instruction], PhysicalReg) = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val (addsInstr, addsAlloc) = genBinaryOp(lhs, rhs, op)
|
||||
instructions ++= addsInstr
|
||||
instructions += BCond(VS, _errOverflow)
|
||||
(instructions.toList, addsAlloc)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates assembly for a unary operator
|
||||
*/
|
||||
def genUnaryOp(expr: Expr, unOp: UnOp)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
(List[Instruction], PhysicalReg) = {
|
||||
|
||||
// Allocate a w register as all unary operators operate on types of 4 bytes
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val op = unOp match {
|
||||
case _: Not => EOR.apply(_, _, Imm1)
|
||||
case _: Len => LDUR.apply
|
||||
case _: Neg => NEGS.apply
|
||||
case _: (Ord | Chr) => MoveOps.apply
|
||||
}
|
||||
|
||||
// Restores a temp register as they will collapse onto each other
|
||||
allocator.restore()
|
||||
val (exprInstr, exprAlloc) = genExpr(expr)
|
||||
|
||||
// Free the temp register as it will be the same as the temp
|
||||
allocator.freeTemp(exprAlloc)
|
||||
val tempAlloc = allocator.allocateTemp(W_REG)
|
||||
|
||||
instructions ++= exprInstr
|
||||
instructions += op(tempAlloc, exprAlloc)
|
||||
|
||||
(instructions.toList, tempAlloc)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates assembly for a comparison expression
|
||||
*/
|
||||
def genComparison(lhs: Expr, rhs: Expr, compOp: CompOp)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
(List[Instruction], PhysicalReg) = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val cond = compOp match {
|
||||
case Gt(lhs, rhs) => GT
|
||||
case Geq(lhs, rhs) => GE
|
||||
case Lt(lhs, rhs) => LT
|
||||
case Leq(lhs, rhs) => LE
|
||||
case Eq(lhs, rhs) => EQ
|
||||
case Neq(lhs, rhs) => NE
|
||||
}
|
||||
|
||||
// Restore a temp register to keep register use efficient
|
||||
allocator.restore()
|
||||
val (lhsInstr, lhsAlloc) = genExpr(lhs)
|
||||
val (rhsInstr, rhsAlloc) = genExpr(rhs)
|
||||
|
||||
// We use a w register as we are dealing with booleans
|
||||
val wLhs = PhysicalReg(lhsAlloc.regNumber, W_REG, lhsAlloc.pool)
|
||||
|
||||
instructions ++= lhsInstr
|
||||
instructions ++= rhsInstr
|
||||
instructions += Compare(lhsAlloc, rhsAlloc)
|
||||
instructions += CSET(wLhs, cond)
|
||||
|
||||
allocator.freeTemp(rhsAlloc)
|
||||
(instructions.toList, wLhs)
|
||||
}
|
||||
|
||||
def genBoolOp(lhs: Expr, rhs: Expr, cond: Flag)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
(List[Instruction], PhysicalReg) = {
|
||||
allocator.restore()
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val (lhsInstr, lhsAlloc) = genExpr(lhs)
|
||||
val boolLabel = labelGenerator.freshLabel()
|
||||
instructions ++= lhsInstr
|
||||
instructions ++= List(
|
||||
Compare(lhsAlloc, Imm1),
|
||||
BCond(cond, boolLabel)
|
||||
)
|
||||
allocator.freeTemp(lhsAlloc)
|
||||
val (rhsInstr, rhsAlloc) = genExpr(rhs)
|
||||
instructions ++= rhsInstr
|
||||
instructions ++= List(
|
||||
Compare(rhsAlloc, Imm1),
|
||||
Label(boolLabel),
|
||||
CSET(rhsAlloc, EQ)
|
||||
)
|
||||
(instructions.toList, rhsAlloc)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates code for a condition
|
||||
* Function guarantees that Compare instruction is generated, to be used by caller function
|
||||
* Returns a tuple: (instructions, branch condition)
|
||||
*/
|
||||
def genCond(cond: Expr)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
(List[Instruction], Flag) = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
cond match {
|
||||
case Neq(lhs, rhs) => (generateOpCond(lhs, rhs), NE)
|
||||
case Eq(lhs, rhs) => (generateOpCond(lhs, rhs), EQ)
|
||||
case Not(expr) => (generateOpCond(expr, BoolLiter(IGNORE, true)), NE)
|
||||
case Gt(lhs, rhs) => (generateOpCond(lhs, rhs), GT)
|
||||
case Lt(lhs, rhs) => (generateOpCond(lhs, rhs), LT)
|
||||
case Geq(lhs, rhs) => (generateOpCond(lhs, rhs), GE)
|
||||
case Leq(lhs, rhs) => (generateOpCond(lhs, rhs), LE)
|
||||
|
||||
case expr: (Ident | And | Or) =>
|
||||
val (exprInstr, exprAlloc) = genExpr(expr)
|
||||
instructions ++= exprInstr
|
||||
instructions += Compare(exprAlloc, Imm1)
|
||||
allocator.freeTemp(exprAlloc)
|
||||
(instructions.toList, EQ)
|
||||
|
||||
case x: BoolLiter =>
|
||||
val (lhsInstr, lhsAlloc) = genExpr(cond)
|
||||
|
||||
instructions ++= lhsInstr
|
||||
instructions += Compare(lhsAlloc, Imm1)
|
||||
|
||||
allocator.freeTemp(lhsAlloc)
|
||||
(instructions.toList, EQ)
|
||||
case _ => (List(Comment("CONDITION NOT CREATED")), FAILEDFLAG)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* generates instructions for a comparison expression
|
||||
*/
|
||||
def generateOpCond(lhs: Expr, rhs: Expr)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
List[Instruction] = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val (lhsInstr, lhsAlloc) = genExpr(lhs)
|
||||
val (rhsInstr, rhsAlloc) = genExpr(rhs)
|
||||
instructions ++= lhsInstr
|
||||
instructions ++= rhsInstr
|
||||
instructions += Compare(lhsAlloc, rhsAlloc)
|
||||
|
||||
allocator.freeTemp(lhsAlloc)
|
||||
allocator.freeTemp(rhsAlloc)
|
||||
instructions.toList
|
||||
}
|
||||
471
src/main/wacc/backend/instructions.scala
Normal file
471
src/main/wacc/backend/instructions.scala
Normal file
@ -0,0 +1,471 @@
|
||||
package wacc.backend
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
val MIN_IMM_VAL = -256
|
||||
val MAX_IMM_VAL = 255
|
||||
|
||||
/**
|
||||
* A representation for AArch64 assembly instructions.
|
||||
* Each instruction is wrapped in a case class that knows how to emit its string.
|
||||
*/
|
||||
sealed trait Instruction {
|
||||
def emit: String
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// Basic instruction classes
|
||||
// --------------------------
|
||||
|
||||
// Write a comment with default indentation
|
||||
case class Comment(name: String, tabs: String = " ") extends Instruction {
|
||||
override def emit: String = s"$tabs// $name"
|
||||
}
|
||||
|
||||
case class Directive(name: String, tabs: String = "") extends Instruction {
|
||||
override def emit: String = s"$tabs.$name"
|
||||
}
|
||||
|
||||
case class Label(name: String) extends Instruction {
|
||||
override def emit: String = s"$name:"
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// Data–movement instructions
|
||||
// --------------------------
|
||||
|
||||
// MOV now takes a destination register and a source operand.
|
||||
case class MOV(dest: Reg, src: Operand) extends Instruction {
|
||||
override def emit: String = s" mov ${dest.toOperand}, ${src.toOperand}"
|
||||
}
|
||||
|
||||
case class MoveOps(dest: Reg, src: Operand) extends Instruction {
|
||||
override def emit: String = getInstrs().map(_.emit).mkString("\n")
|
||||
def getInstrs(): List[Instruction] =
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val bitMask = 0xFFFF
|
||||
val shift16 = 16
|
||||
src match {
|
||||
case _: Reg => instructions += MOV(dest, src)
|
||||
case _: LabelOp => () // Label Op shouldn't be in the move (at least not for now?)
|
||||
case Immediate(value) =>
|
||||
val valInt = value.toInt
|
||||
valInt match {
|
||||
// Handles immediate values that are greater than 0xFFFF and breals them into two parts
|
||||
case _ if (valInt >= 0 && valInt < bitMask) => instructions += MOV(dest, src)
|
||||
case _ =>
|
||||
val immOne = Immediate((valInt & bitMask))
|
||||
val immTwo = Immediate(((valInt >> shift16) & bitMask))
|
||||
instructions += MOV(dest, immOne)
|
||||
instructions += MOVK(dest, immTwo, LSL(Imm16))
|
||||
}
|
||||
}
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
case class MOVK(dest: Reg, src: Operand, shift: Shift) extends Instruction {
|
||||
override def emit: String = s" movk ${dest.toOperand}, ${src.toOperand}, ${shift.emit}"
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// Arithmetic instructions
|
||||
// --------------------------
|
||||
|
||||
case class ADD(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
|
||||
override def emit: String = s" add ${dest.toOperand}, ${op1.toOperand}, ${op2.toOperand}"
|
||||
}
|
||||
|
||||
case class ADDS(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
|
||||
override def emit: String = s" adds ${dest.toOperand}, ${op1.toOperand}, ${op2.toOperand}"
|
||||
}
|
||||
|
||||
case class SUB(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
|
||||
override def emit: String = s" sub ${dest.toOperand}, ${op1.toOperand}, ${op2.toOperand}"
|
||||
}
|
||||
|
||||
case class SUBS(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
|
||||
override def emit: String = s" subs ${dest.toOperand}, ${op1.toOperand}, ${op2.toOperand}"
|
||||
}
|
||||
|
||||
case class SMULL(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
|
||||
override def emit: String = s" smull ${dest.toOperand}, ${op1.toOperand}, ${op2.toOperand}"
|
||||
}
|
||||
|
||||
case class SDIV(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
|
||||
override def emit: String = s" sdiv ${dest.toOperand}, ${op1.toOperand}, ${op2.toOperand}"
|
||||
}
|
||||
|
||||
case class MSUB(dest: Reg, op1: Operand, op2: Operand, op3: Operand) extends Instruction {
|
||||
override def emit: String = s" msub ${dest.toOperand}, ${op1.toOperand}, ${op2.toOperand}, ${op3.toOperand}"
|
||||
}
|
||||
|
||||
case class NEGS(dest: Reg, src: Operand) extends Instruction {
|
||||
override def emit: String = s" negs ${dest.toOperand}, ${src.toOperand}"
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// Bitwise instructions
|
||||
// --------------------------
|
||||
|
||||
case class EOR(dest: Reg, op1: Operand, op2: Operand) extends Instruction {
|
||||
override def emit: String = s" eor ${dest.toOperand}, ${op1.toOperand}, ${op2.toOperand}"
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// Address instructions
|
||||
// --------------------------
|
||||
|
||||
case class ADRP(dest: Reg, label: String) extends Instruction {
|
||||
override def emit: String = s" adrp ${dest.toOperand}, $label"
|
||||
}
|
||||
|
||||
case class ADR(dest: Reg, label: String) extends Instruction {
|
||||
override def emit: String = s" adr ${dest.toOperand}, $label"
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// Comparison and flag-setting
|
||||
// --------------------------
|
||||
|
||||
case class SXTW() extends Instruction {
|
||||
override def emit: String = "sxtw"
|
||||
}
|
||||
|
||||
case class CMP(op1: Operand, op2: Operand, _sxtw: Option[SXTW] = None) extends Instruction {
|
||||
override def emit: String = _sxtw match {
|
||||
case None => s" cmp ${op1.toOperand}, ${op2.toOperand}"
|
||||
case Some(sxtw) => s" cmp ${op1.toOperand}, ${op2.toOperand}, ${sxtw.emit}"
|
||||
}
|
||||
}
|
||||
|
||||
case class CSET(dest: Reg, condition: Flag) extends Instruction {
|
||||
override def emit: String = s" cset ${dest.toOperand}, ${condition.emit}"
|
||||
}
|
||||
|
||||
case class TST(reg: Reg, op: Operand) extends Instruction {
|
||||
override def emit: String = s" tst ${reg.toOperand}, ${op.toOperand}"
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// Branch instructions
|
||||
// --------------------------
|
||||
|
||||
case class B(label: String) extends Instruction {
|
||||
override def emit: String = s" b $label"
|
||||
}
|
||||
|
||||
case class BL(label: String, _ignore: String = "") extends Instruction {
|
||||
override def emit: String = s" bl $label"
|
||||
}
|
||||
|
||||
object BL {
|
||||
def apply(label: String)(using labelGenerator: LabelGenerator): BL = {
|
||||
labelGenerator.addWidget(label)
|
||||
new BL(label)
|
||||
}
|
||||
}
|
||||
|
||||
case class BCond(flag: Flag, label: String, _ignore: String = "") extends Instruction {
|
||||
override def emit: String = s" b.${flag.emit} $label"
|
||||
}
|
||||
|
||||
object BCond {
|
||||
def apply(flag: Flag, label: String)(using labelGenerator: LabelGenerator): BCond = {
|
||||
labelGenerator.addWidget(label)
|
||||
new BCond(flag, label)
|
||||
}
|
||||
}
|
||||
|
||||
case class CBZ(reg: Reg, label: String) extends Instruction {
|
||||
override def emit: String = s" cbz ${reg.toOperand}, $label"
|
||||
}
|
||||
|
||||
case class CSEL(dest: Reg, src1: Reg, src2: Reg, condition: Flag) extends Instruction {
|
||||
override def emit: String = s" csel ${dest.toOperand}, ${src1.toOperand}, ${src2.toOperand}, ${condition.emit}"
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// Load/store instructions
|
||||
// --------------------------
|
||||
|
||||
// A shift operation
|
||||
sealed trait Shift {
|
||||
def emit: String
|
||||
}
|
||||
case class LSL(amount: Immediate) extends Shift {
|
||||
override def emit: String = s"lsl ${amount.toOperand}"
|
||||
}
|
||||
|
||||
// LDR
|
||||
case class LDR(dest: Reg, base: Reg, offset: Option[Operand] = None, shift: Option[Shift] = None) extends Instruction {
|
||||
override def emit: String = (offset, shift) match {
|
||||
case (None, None) => s" ldr ${dest.toOperand}, [${base.toOperand}]"
|
||||
case (Some(off), None) => s" ldr ${dest.toOperand}, [${base.toOperand}, ${off.toOperand}]"
|
||||
case (Some(off), Some(sft)) => s" ldr ${dest.toOperand}, [${base.toOperand}, ${off.toOperand}, ${sft.emit}]"
|
||||
case _ => s" ldr ${dest.toOperand}, [${base.toOperand}]"
|
||||
}
|
||||
}
|
||||
|
||||
// LDUR
|
||||
case class LDUR(dest: Reg, base: Operand) extends Instruction {
|
||||
override def emit: String = s" ldur ${dest.toOperand}, [${base.toOperand}, #-4]"
|
||||
}
|
||||
|
||||
// LDP
|
||||
case class LDP(dest1: Reg, dest2: Reg, base: Reg) extends Instruction {
|
||||
override def emit: String = s" ldp ${dest1.toOperand}, ${dest2.toOperand}, [${base.toOperand}], #16"
|
||||
}
|
||||
|
||||
// LDRB
|
||||
case class LDRB(dest: Reg, base: Reg, offset: Operand) extends Instruction {
|
||||
override def emit: String = s" ldrb ${dest.toOperand}, [${base.toOperand}, ${offset.toOperand}]"
|
||||
}
|
||||
|
||||
// STP
|
||||
case class STP(src1: Reg, src2: Reg, base: Reg, offset: Option[Operand] = None) extends Instruction {
|
||||
override def emit: String = offset match {
|
||||
case None => s" stp ${src1.toOperand}, ${src2.toOperand}, [${base.toOperand}, #-16]!"
|
||||
case Some(off) => s" stp ${src1.toOperand}, ${src2.toOperand}, [${base.toOperand}, ${off.toOperand}]!"
|
||||
}
|
||||
}
|
||||
|
||||
// STR
|
||||
case class STR(src: Reg, base: Reg, offset: Option[Operand] = None, shift: Option[Shift] = None) extends Instruction {
|
||||
override def emit: String = (offset, shift) match {
|
||||
case (None, None) => s" str ${src.toOperand}, [${base.toOperand}]"
|
||||
case (Some(off), None) => s" str ${src.toOperand}, [${base.toOperand}, ${off.toOperand}]"
|
||||
case (Some(off), Some(sft)) => s" str ${src.toOperand}, [${base.toOperand}, ${off.toOperand}, ${sft.emit}]"
|
||||
case _ => s" str ${src.toOperand}, [${base.toOperand}]"
|
||||
}
|
||||
}
|
||||
|
||||
// STRB
|
||||
case class STRB(src: Operand, base: Reg, offset: Operand) extends Instruction {
|
||||
override def emit: String = s" strb ${src.toOperand}, [${base.toOperand}, ${offset.toOperand}]"
|
||||
}
|
||||
|
||||
// STUR
|
||||
case class STUR(src: Operand, base: Reg) extends Instruction {
|
||||
override def emit: String = s" stur ${src.toOperand}, [${base.toOperand}, #-4]"
|
||||
}
|
||||
|
||||
case class RET() extends Instruction {
|
||||
override def emit: String = " ret"
|
||||
}
|
||||
|
||||
case class EOF() extends Instruction {
|
||||
override def emit: String = ""
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// Bool flags
|
||||
// --------------------------
|
||||
|
||||
sealed trait Flag extends Instruction
|
||||
case object NE extends Flag {
|
||||
override def emit: String = "ne"
|
||||
}
|
||||
|
||||
case object EQ extends Flag {
|
||||
override def emit: String = "eq"
|
||||
}
|
||||
|
||||
case object GT extends Flag {
|
||||
override def emit: String = "gt"
|
||||
}
|
||||
|
||||
case object LT extends Flag {
|
||||
override def emit: String = "lt"
|
||||
}
|
||||
|
||||
case object GE extends Flag {
|
||||
override def emit: String = "ge"
|
||||
}
|
||||
|
||||
case object LE extends Flag {
|
||||
override def emit: String = "le"
|
||||
}
|
||||
|
||||
case object VS extends Flag {
|
||||
override def emit: String = "vs"
|
||||
}
|
||||
case object FAILEDFLAG extends Flag {
|
||||
override def emit: String = "failed"
|
||||
}
|
||||
|
||||
// ----------------------------------------------------
|
||||
// General Instruction Classes
|
||||
// ----------------------------------------------------
|
||||
|
||||
// Contains all classes which perform the generic movement
|
||||
// (To allow for all allocated types to work)
|
||||
|
||||
/**
|
||||
* A class which invokes the correct moving instruction for
|
||||
*
|
||||
*/
|
||||
case class Move(dest: Allocated, src: Allocated)
|
||||
(using allocator: RegisterAllocator) extends Instruction {
|
||||
|
||||
// Emits the correct store instruction depending on the contents of src and dest
|
||||
override def emit: String = getInstrs().map(_.emit).mkString("\n")
|
||||
|
||||
// Functions pertaining instruction allocation
|
||||
def getInstrs(): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
dest match {
|
||||
// Check on the whether the operand is on the stack or a register
|
||||
|
||||
// Using either MOV or LDR instructions
|
||||
case regDest: Reg => src match {
|
||||
// Move the Src operand to the Dest register
|
||||
case opSrc: (Reg | Immediate) => instructions += MoveOps(regDest, opSrc)
|
||||
|
||||
// Load the Src register to the Dest register
|
||||
case Spill(offsetSrc, basePointer) => instructions += Load(regDest, basePointer, Some(offsetSrc))
|
||||
}
|
||||
// Using STR instructions
|
||||
case Spill(offsetDest, basePointer) => src match {
|
||||
case regSrc: (Reg) =>
|
||||
instructions += Store(regSrc, basePointer, Some(offsetDest))
|
||||
|
||||
// TODO: Might need to add logic for tempReg Spill creation so that there's always a free
|
||||
// temp alloc to move things to
|
||||
// We will Load then store here
|
||||
case Spill(offsetSrc, basePointer) =>
|
||||
// Create a temp x bit-mode register
|
||||
val tempAlloc = allocator.allocateTemp(X_REG)
|
||||
tempAlloc match {
|
||||
case reg: Reg =>
|
||||
instructions += Load(reg, basePointer, Some(offsetDest))
|
||||
instructions += Store(reg, basePointer, Some(offsetSrc))
|
||||
}
|
||||
allocator.freeTemp(tempAlloc) // Free it (location doesn't matter)
|
||||
}
|
||||
}
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A class which invokes the correct storing instruction
|
||||
*
|
||||
*/
|
||||
case class Store(reg1: Reg, alloc2: Allocated, _imm3: Option[Immediate] = None)
|
||||
(using allocator: RegisterAllocator) extends Instruction {
|
||||
|
||||
// Emits the correct store instruction depending on the contents of src and dest
|
||||
override def emit: String = getInstrs().map(_.emit).mkString("\n")
|
||||
|
||||
def getInstrs(): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val regTemp = allocator.allocateTemp(X_REG)
|
||||
_imm3 match {
|
||||
case None => alloc2 match {
|
||||
case reg2: Reg => instructions += STR(reg1, reg2)
|
||||
case Spill(offset2, basePointer) => instructions += STR(reg1, basePointer, Some(offset2))
|
||||
}
|
||||
case Some(imm3) => alloc2 match {
|
||||
case reg2: Reg => imm3 match {
|
||||
|
||||
// If the offset is less than -256, then we need to move the value to a register first
|
||||
case Immediate(negVal) if negVal.toInt < -256 =>
|
||||
instructions += MOV(X17, imm3)
|
||||
instructions += STR(reg1, reg2, Some(X17))
|
||||
|
||||
// If the offset is greater than 255, then we need to move the value to a register first
|
||||
case Immediate(posVal) if posVal.toInt > 255 =>
|
||||
instructions += MoveOps(X17, imm3)
|
||||
instructions += STR(reg1, reg2, Some(X17))
|
||||
case _ =>
|
||||
instructions += STR(reg1, reg2, Some(imm3))
|
||||
|
||||
}
|
||||
case Spill(offset2, basePointer) =>
|
||||
instructions += Load(regTemp, basePointer, Some(offset2))
|
||||
instructions += STR(reg1, regTemp, Some(imm3))
|
||||
}
|
||||
}
|
||||
|
||||
allocator.freeTemp(regTemp)
|
||||
instructions.toList
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A class which invokes the correct compare instruction
|
||||
*
|
||||
*/
|
||||
case class Compare(alloc1: Allocated, op2: Operand, _sxtw: Option[SXTW] = None)
|
||||
(using allocator: RegisterAllocator) extends Instruction {
|
||||
override def emit: String = getInstrs().map(_.emit).mkString("\n")
|
||||
|
||||
def getInstrs(): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
alloc1 match {
|
||||
case reg: Reg => instructions += CMP(reg, op2, _sxtw)
|
||||
case Spill(offset, basePointer) =>
|
||||
val regTemp = op2 match {
|
||||
case reg: Reg if reg.bitMode == W_REG => allocator.allocateTemp(W_REG)
|
||||
case _: Immediate => allocator.allocateTemp(W_REG)
|
||||
case _ => allocator.allocateTemp(X_REG)
|
||||
}
|
||||
|
||||
instructions += LDR(regTemp, basePointer, Some(offset))
|
||||
instructions += CMP(regTemp, op2, _sxtw)
|
||||
allocator.freeTemp(regTemp)
|
||||
}
|
||||
instructions.toList
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A class which invokes the correct load instruction
|
||||
*
|
||||
*/
|
||||
case class Load(dest: Reg, base: Allocated, offset: Option[Operand] = None, shift: Option[Shift] = None)
|
||||
(using allocator: RegisterAllocator) extends Instruction {
|
||||
override def emit: String = getInstrs().map(_.emit).mkString("\n")
|
||||
|
||||
def getInstrs(): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
base match {
|
||||
case reg: Reg =>
|
||||
offset match {
|
||||
|
||||
// If the offset is less than -256, then we need to move the value to a register first
|
||||
case Some(negImm@Immediate(negVal)) if negVal.toInt < -256 =>
|
||||
instructions += MOV(X17, negImm)
|
||||
instructions += LDR(dest, reg, Some(X17), shift)
|
||||
|
||||
// If the offset is greater than 255, then we need to move the value to a register first
|
||||
case Some(posImm@Immediate(posVal)) if posVal.toInt > 255 =>
|
||||
instructions += MoveOps(X17, posImm)
|
||||
instructions += LDR(dest, reg, Some(X17), shift)
|
||||
|
||||
// Otherwise, just load the values as is
|
||||
case _ => instructions += LDR(dest, reg, offset, shift)
|
||||
|
||||
}
|
||||
|
||||
case Spill(offsetBase, basePointer) =>
|
||||
|
||||
offset match {
|
||||
|
||||
// We load twice if the offset exists
|
||||
case Some(_) =>
|
||||
val tempReg = allocator.allocateTemp(X_REG)
|
||||
instructions += LDR(tempReg, basePointer, Some(offsetBase))
|
||||
instructions += LDR(dest, tempReg, offset, shift)
|
||||
allocator.freeTemp(tempReg)
|
||||
|
||||
// We load straight to the destination
|
||||
case None => instructions += LDR(dest, basePointer, Some(offsetBase))
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
}
|
||||
399
src/main/wacc/backend/labelGen.scala
Normal file
399
src/main/wacc/backend/labelGen.scala
Normal file
@ -0,0 +1,399 @@
|
||||
package wacc.backend
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
// All libc definitions
|
||||
val scanf = "scanf"
|
||||
val exit = "exit"
|
||||
val free = "free"
|
||||
val puts = "puts"
|
||||
val printf = "printf"
|
||||
val fflush = "fflush"
|
||||
val malloc = "malloc"
|
||||
|
||||
// All Standard Label definitions
|
||||
val _prints = "_prints"
|
||||
val _printc = "_printc"
|
||||
val _printi = "_printi"
|
||||
val _printb = "_printb"
|
||||
val _printp = "_printp"
|
||||
val _println = "_println"
|
||||
val _readi = "_readi"
|
||||
val _readc = "_readc"
|
||||
val _malloc = "_malloc"
|
||||
val _errOutOfMemory = "_errOutOfMemory"
|
||||
val _errOutOfBounds = "_errOutOfBounds"
|
||||
val _errBadChar = "_errBadChar"
|
||||
val _errOverflow = "_errOverflow"
|
||||
val _errDivZero = "_errDivZero"
|
||||
val _errNull = "_errNull"
|
||||
def _arrLoad(size: String) = s"_arrLoad$size"
|
||||
def _arrStore(size: String) = s"_arrStore$size"
|
||||
val _freepair = "_freepair"
|
||||
|
||||
/**
|
||||
* LabelGenerator provides unique labels used for branch targets
|
||||
* We use a simple counter that is incremented on every request
|
||||
* Also produces standard labels (for read, print, and any other call that may be needed)
|
||||
*/
|
||||
class LabelGenerator {
|
||||
private var branchCounter = -1
|
||||
private var stringCounter = -1
|
||||
|
||||
private def prologue(label: String, reg1: Reg, reg2: Reg): List[Instruction] = List (
|
||||
Label(label),
|
||||
Comment(s"push {${reg1.toOperand}${if (reg2 != XZR) s", ${reg2.toOperand}" else ""}}"),
|
||||
STP(reg1, reg2, SP)
|
||||
)
|
||||
|
||||
private def epilogue(reg1: Reg, reg2: Reg): List[Instruction] = List (
|
||||
Comment(s"pop {${reg1.toOperand}${if (reg2 != XZR) s", ${reg2.toOperand}" else ""}}"),
|
||||
LDP(reg1, reg2, SP),
|
||||
RET()
|
||||
)
|
||||
|
||||
// Maps each label name to its full instruction set
|
||||
private val widgetMap: Map[String, List[Instruction]] = Map(
|
||||
_prints -> printInstr('s'), _printc -> printInstr('c'), _printi -> printInstr('i'),
|
||||
_printb -> printInstr('b'), _printp -> printInstr('p'), _println -> printInstr('l'),
|
||||
_readi -> readInstr('i'), _readc -> readInstr('c'), _malloc -> mallocInstr(),
|
||||
_errOutOfMemory -> errInstr('m'), _errOutOfBounds -> errInstr('b'),
|
||||
_errBadChar -> errInstr('c'), _errOverflow -> errInstr('o'), _errDivZero -> errInstr('z'),
|
||||
_errNull -> errInstr('n'), _arrLoad("1") -> arrNInstr("Load", '1'),
|
||||
_arrLoad("4") -> arrNInstr("Load", '4'), _arrLoad("8") -> arrNInstr("Load", '8'),
|
||||
_arrStore("1") -> arrNInstr("Store", '1'), _arrStore("4") -> arrNInstr("Store", '4'),
|
||||
_arrStore("8") -> arrNInstr("Store", '8'), _freepair -> freePairInstr()
|
||||
)
|
||||
|
||||
// Contains the unique standard labels that have been called through the program
|
||||
private val widgetSet: mutable.Set[String] = mutable.Set.empty
|
||||
|
||||
// A map to hold all string literals (so that duplicates reuse the same label)
|
||||
private val stringLiterals: mutable.Map[String, (String, String)] = mutable.Map.empty
|
||||
|
||||
def getStringLiterals(): List[(String, (String, String))] = stringLiterals.toSeq.sortBy {
|
||||
// Sort by the number of the string label
|
||||
_._2._2.stripPrefix(".L.str").toInt
|
||||
}.toList
|
||||
|
||||
/**
|
||||
* Escapes special characters in a string literal
|
||||
*/
|
||||
def escapeString(s: String): String = {
|
||||
s.flatMap {
|
||||
case '\n' => "\\n"
|
||||
case '\t' => "\\t"
|
||||
case '\r' => "\\r"
|
||||
case '\"' => "\\\""
|
||||
case '\\' => "\\\\"
|
||||
case c => c.toString
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates all of the standard labels, this depends on how populated the set is
|
||||
*/
|
||||
def genWidgets(): List[Instruction] = {
|
||||
widgetSet.toSeq.sortBy(str => str).toList.flatMap(EOF() +: widgetMap.get(_).get)
|
||||
}
|
||||
|
||||
def addWidget(label: String): Unit = {
|
||||
if (widgetMap.contains(label)) {
|
||||
widgetSet += label
|
||||
label match {
|
||||
// Add the errOutOfMemory and prints labels for a malloc
|
||||
case s"$mallocLabel" if mallocLabel == _malloc =>
|
||||
widgetSet += _errOutOfMemory
|
||||
widgetSet += _prints
|
||||
|
||||
// Add the errNull and prints labels for a freepair
|
||||
case s"$freepairLabel" if freepairLabel == _freepair =>
|
||||
widgetSet += _errNull
|
||||
widgetSet += _prints
|
||||
|
||||
// For overflow, div by zer, or null errors, add the prints label
|
||||
case s"$errLabel" if errLabel == _errOverflow || errLabel == _errDivZero || errLabel == _errNull =>
|
||||
widgetSet += _prints
|
||||
|
||||
// Matches on the 3 arrLoad/arrStore standard labels and adds the _errOutOfBounds label to the set
|
||||
case label if (label.startsWith(_arrLoad("")) || label.startsWith(_arrStore(""))) =>
|
||||
widgetSet += _errOutOfBounds
|
||||
|
||||
case _ => ()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def freshLabel(str: String): String = {
|
||||
// If the map doesn't contain the string, then generate a new label
|
||||
if (!stringLiterals.contains(str)) {
|
||||
stringCounter += 1
|
||||
stringLiterals += (str -> (escapeString(str), s".L.str$stringCounter"))
|
||||
}
|
||||
|
||||
// Return the label
|
||||
stringLiterals.get(str).get._2
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a fresh label for a branching instruction (while/if-else)
|
||||
* Returns the label string
|
||||
*/
|
||||
def freshLabel(): String = {
|
||||
branchCounter += 1
|
||||
s".L${branchCounter}"
|
||||
}
|
||||
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Standard labels
|
||||
// ----------------------------------------------------
|
||||
|
||||
/**
|
||||
* Formats print label instructions for chars, strings, ints, pointers, booleans, and println
|
||||
*/
|
||||
def printInstr(_type: Char): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val typeToEsc = Map('p' -> "%p", 'c' -> "%c", 'i' -> "%d")
|
||||
val typePrinted = if (_type == 'l') "ln" else _type
|
||||
|
||||
instructions += Comment(s"length of .L._print${typePrinted}_str0", "")
|
||||
// Add the meta data for the correct print type
|
||||
_type match {
|
||||
case chr @ ('p' | 'c' | 'i') =>
|
||||
instructions ++= List(
|
||||
Directive("word 2", " "),
|
||||
Label(s".L._print${chr}_str0"),
|
||||
Directive(s"asciz \"${typeToEsc.get(chr).get}\"", " ")
|
||||
)
|
||||
case 's' =>
|
||||
instructions ++= List(
|
||||
Directive("word 4", " "),
|
||||
Label(s".L._prints_str0"),
|
||||
Directive(s"asciz \"%.*s\"", " ")
|
||||
)
|
||||
case 'b' =>
|
||||
instructions ++= List(
|
||||
Directive("word 5", " "),
|
||||
Label(s".L._printb_str0"),
|
||||
Directive(s"asciz \"false\"", " "),
|
||||
Comment(s"length of .L._printb_str1", ""),
|
||||
Directive("word 4", " "),
|
||||
Label(s".L._printb_str1"),
|
||||
Directive(s"asciz \"true\"", " "),
|
||||
Comment(s"length of .L._printb_str2", ""),
|
||||
Directive("word 4", " "),
|
||||
Label(s".L._printb_str2"),
|
||||
Directive(s"asciz \"%.*s\"", " ")
|
||||
)
|
||||
|
||||
case 'l' =>
|
||||
instructions ++= List(
|
||||
Directive("word 0", " "),
|
||||
Label(s".L._println_str0"),
|
||||
Directive(s"asciz \"\"", " ")
|
||||
)
|
||||
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
// Common preamble for all print instructions
|
||||
instructions += Directive("align 4")
|
||||
instructions ++= prologue(s"_print${typePrinted}", LR, XZR)
|
||||
|
||||
// Generate the address manipulation depending on the type of print
|
||||
_type match {
|
||||
case 'p' | 'c' | 'i' =>
|
||||
instructions += MoveOps(X1, X0)
|
||||
|
||||
case 's' =>
|
||||
instructions += MoveOps(X2, X0)
|
||||
instructions += LDUR(W1, X0)
|
||||
|
||||
case 'b' =>
|
||||
instructions ++= List(
|
||||
CMP(W0, Imm0),
|
||||
new BCond(NE, ".L_printb0"),
|
||||
ADR(X2, s".L._printb_str0"),
|
||||
B(".L_printb1"),
|
||||
Label(".L_printb0"),
|
||||
ADR(X2, s".L._printb_str1"),
|
||||
Label(".L_printb1"),
|
||||
LDUR(W1, X2)
|
||||
)
|
||||
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
// Generate the remaining routine depending on the type of print
|
||||
if (_type == 'b') {
|
||||
instructions += ADR(X0, s".L._printb_str2")
|
||||
} else {
|
||||
instructions += ADR(X0, s".L._print${typePrinted}_str0")
|
||||
}
|
||||
if (_type == 'l') {
|
||||
instructions += new BL(puts)
|
||||
} else {
|
||||
instructions += new BL(printf)
|
||||
}
|
||||
|
||||
instructions += MoveOps(X0, Imm0)
|
||||
instructions += new BL(fflush)
|
||||
instructions ++= epilogue(LR, XZR)
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
/**
|
||||
* Formats a standard read label instruction for chars, strings, ints, pointers, and booleans
|
||||
*/
|
||||
def readInstr(_type: Char): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val label = s".L._read${_type}_str0"
|
||||
|
||||
instructions += Comment(s"length of $label", "")
|
||||
|
||||
_type match {
|
||||
case 'i' =>
|
||||
instructions ++= List(
|
||||
Directive("word 2", " "),
|
||||
Label(label),
|
||||
Directive(s"asciz \"%d\"", " ")
|
||||
)
|
||||
|
||||
case 'c' =>
|
||||
instructions ++= List(
|
||||
Directive("word 3", " "),
|
||||
Label(label),
|
||||
Directive(s"asciz \" %c\"", " ")
|
||||
)
|
||||
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
instructions += Directive("align 4")
|
||||
instructions ++= prologue(label.replace(".L.", "").replace("_str0", ""), X0, LR)
|
||||
instructions ++= List(
|
||||
MoveOps(X1, SP),
|
||||
ADR(X0, label),
|
||||
new BL(scanf),
|
||||
)
|
||||
|
||||
instructions ++= epilogue(X0, LR)
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
def mallocInstr(): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
instructions ++= prologue(_malloc, LR, XZR)
|
||||
instructions ++= List(
|
||||
new BL(malloc),
|
||||
CBZ(X0, _errOutOfMemory),
|
||||
)
|
||||
|
||||
instructions ++= epilogue(LR, XZR)
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
def errInstr(errType: Char): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// Create variables for the name and word length for the type
|
||||
val (errName, errWord, errMsg) = errType match {
|
||||
case 'm' => (_errOutOfMemory, "word 27", "out of memory")
|
||||
case 'b' => (_errOutOfBounds, "word 42", "array index %d out of bounds")
|
||||
case 'c' => (_errBadChar, "word 50", "int %d is not ascii character 0-127")
|
||||
case 'o' => (_errOverflow, "word 52", "integer overflow or underflow occurred")
|
||||
case 'z' => (_errDivZero, "word 40", "division or modulo by zero")
|
||||
case 'n' => (_errNull, "word 45", "null pair dereferenced or freed")
|
||||
case _ => ("", "", "")
|
||||
}
|
||||
|
||||
// Add the preamble for instructions
|
||||
instructions ++= List(
|
||||
Comment(s"length of .L.${errName}_str0", ""),
|
||||
Directive(errWord, " "),
|
||||
Label(s".L.${errName}_str0"),
|
||||
Directive(s"asciz \"fatal error: ${errMsg}\\n\"", " "),
|
||||
Directive("align 4"),
|
||||
Label(errName),
|
||||
ADR(X0, s".L.${errName}_str0")
|
||||
)
|
||||
|
||||
val remainingInstr = errType match {
|
||||
case ('m' | 'o' | 'z' | 'n') => List(
|
||||
new BL(_prints)
|
||||
)
|
||||
case ('b' | 'c') => List(
|
||||
new BL(printf),
|
||||
MoveOps(X0, Imm0),
|
||||
new BL(fflush)
|
||||
)
|
||||
case _ => List()
|
||||
}
|
||||
|
||||
instructions ++= remainingInstr
|
||||
|
||||
// Add on the final common part for all error labels
|
||||
instructions ++= List(
|
||||
MoveOps(W0, Imm_1),
|
||||
new BL(exit)
|
||||
)
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
|
||||
def arrNInstr(_type: String, n :Char): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
_type match {
|
||||
case "Store" => Comment(s"Special calling convention: array ptr passed in X7, index in X17," +
|
||||
" value to store in X8, LR(W30) is used as general register")
|
||||
case "Load" => Comment("Special calling convention: array ptr passed in X7, index in X17," +
|
||||
" LR(W30) is used as general register, and return into X7")
|
||||
}
|
||||
instructions ++= prologue(s"_arr$_type$n", LR, XZR)
|
||||
instructions ++= List(
|
||||
CMP(W17, Imm0),
|
||||
Comment("this must be a 64-bit move so that it doesn't truncate if the move fails"),
|
||||
CSEL(X1, X17, X1, LT),
|
||||
new BCond(LT, _errOutOfBounds),
|
||||
LDUR(W30, X7),
|
||||
CMP(W17, W30),
|
||||
Comment("this must be a 64-bit move so that it doesn't truncate if the move fails"),
|
||||
CSEL(X1, X17, X1, GE),
|
||||
new BCond(GE, _errOutOfBounds)
|
||||
)
|
||||
|
||||
_type match {
|
||||
case "Load" => n match {
|
||||
case '8' => instructions += LDR(X7, X7, Some(X17), Some(LSL(Imm3)))
|
||||
case '4' => instructions += LDR(W7, X7, Some(X17), Some(LSL(Imm2)))
|
||||
case '1' => instructions += LDRB(W7, X7, X17)
|
||||
case _ => ()
|
||||
}
|
||||
case "Store" => n match {
|
||||
case '8' => instructions += STR(X8, X7, Some(X17), Some(LSL(Imm3)))
|
||||
case '4' => instructions += STR(W8, X7, Some(X17), Some(LSL(Imm2)))
|
||||
case '1' => instructions += STRB(W8, X7, X17)
|
||||
case _ => ()
|
||||
}
|
||||
}
|
||||
|
||||
instructions ++= epilogue(LR, XZR)
|
||||
instructions.toList
|
||||
|
||||
}
|
||||
|
||||
def freePairInstr(): List[Instruction] = {
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
instructions ++= prologue(_freepair, LR, XZR)
|
||||
instructions += CBZ(X0, _errNull)
|
||||
instructions += new BL(free)
|
||||
instructions ++= epilogue(LR, XZR)
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
}
|
||||
311
src/main/wacc/backend/lhsRhsGen.scala
Normal file
311
src/main/wacc/backend/lhsRhsGen.scala
Normal file
@ -0,0 +1,311 @@
|
||||
package wacc.backend
|
||||
import wacc.frontend.syntax.ast._
|
||||
import wacc.frontend.semantic.environment._
|
||||
import scala.collection.mutable
|
||||
|
||||
/**
|
||||
* Generates assembly for a RHS
|
||||
*/
|
||||
def genRhs(rhs: Rhs, cGen: CodeGenerator)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator,
|
||||
gFuncEnv: GlobalFuncEnv, funcRegList: List[(String, Allocated)] = List.empty):
|
||||
(List[Instruction], Reg) = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
rhs match {
|
||||
case expr: Expr => genExpr(expr)
|
||||
|
||||
case NewPair(_, fst, snd) =>
|
||||
|
||||
// Save argument registers
|
||||
instructions ++= cGen.generatePushInstructions(funcRegList, callingFunc = true)
|
||||
|
||||
instructions ++= List(
|
||||
MoveOps(W0, Imm16),
|
||||
BL(_malloc),
|
||||
MoveOps(X16, X0)
|
||||
)
|
||||
|
||||
// Pop off argument registers
|
||||
instructions ++= cGen.generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
|
||||
|
||||
val (fstInstr, fstAlloc) = genExpr(fst)
|
||||
val xFst = PhysicalReg(fstAlloc.regNumber, X_REG, fstAlloc.pool)
|
||||
instructions ++= fstInstr
|
||||
instructions += STR(xFst, X16)
|
||||
allocator.freeTemp(fstAlloc)
|
||||
|
||||
val (sndInstr, sndAlloc) = genExpr(snd)
|
||||
val xSnd = PhysicalReg(sndAlloc.regNumber, X_REG, sndAlloc.pool)
|
||||
instructions ++= sndInstr
|
||||
instructions += STR(xSnd, X16, Some(Imm8))
|
||||
allocator.freeTemp(sndAlloc)
|
||||
|
||||
val pairAlloc = allocator.allocateTemp(X_REG)
|
||||
instructions += MoveOps(pairAlloc, X16)
|
||||
(instructions.toList, pairAlloc)
|
||||
|
||||
case arr@ArrayLiter(_, elems) =>
|
||||
|
||||
// Save argument registers
|
||||
instructions ++= cGen.generatePushInstructions(funcRegList, callingFunc = true)
|
||||
|
||||
// Allocate the first temp
|
||||
val tempAlloc1 = allocator.allocateTemp(W_REG)
|
||||
|
||||
// Create the immediate that is to be used for mallocing to the result register
|
||||
val typeSize = sizeof(exprType(elems.headOption.getOrElse(PairLiter(IGNORE)))(using gMainEnv))
|
||||
val arrImm = Immediate((typeSize * elems.length) + 4)
|
||||
val elemsImm = Immediate(elems.length)
|
||||
|
||||
// Add the preamble for adding elements to the array
|
||||
instructions ++= List(
|
||||
Comment(s"${elems.length} element array"),
|
||||
MoveOps(W0, arrImm),
|
||||
BL(_malloc),
|
||||
MoveOps(X16, X0)
|
||||
)
|
||||
|
||||
// Pop off argument registers
|
||||
instructions ++= cGen.generatePopInstructions(funcRegList.reverse.toList, callingFunc = true)
|
||||
|
||||
instructions ++= List(
|
||||
Comment("array pointers are shifted forwards by 4 bytes (to account for size)"),
|
||||
ADD(X16, X16, Imm4),
|
||||
MoveOps(tempAlloc1, elemsImm),
|
||||
STUR(tempAlloc1, X16)
|
||||
)
|
||||
allocator.freeTemp(tempAlloc1)
|
||||
|
||||
// Add each element to the memory location (depending on the type of the element)
|
||||
instructions ++= elems.zipWithIndex.flatMap{(elem: Expr, index: Int) =>
|
||||
|
||||
// Allocate the temp allocation for both bit-modes as they are both needed
|
||||
val (exprInstr, exprAlloc) = genExpr(elem)
|
||||
val (exprAllocX, exprAllocW) = exprAlloc match {
|
||||
case xReg@PhysicalReg(regNumber, bitMode, pool) =>
|
||||
(xReg, PhysicalReg(regNumber, W_REG, pool))
|
||||
}
|
||||
|
||||
// Add the specific element type
|
||||
val indexImm = Immediate(index * typeSize)
|
||||
val specificInstr = exprType(elem)(using gMainEnv) match {
|
||||
case IntType => List(
|
||||
STR(exprAllocW, X16, Some(indexImm))
|
||||
)
|
||||
case (CharType | BoolType) => List(
|
||||
STRB(exprAllocW, X16, indexImm)
|
||||
)
|
||||
case _ => List(
|
||||
STR(exprAllocX, X16, Some(indexImm))
|
||||
)
|
||||
}
|
||||
|
||||
allocator.freeTemp(exprAlloc)
|
||||
exprInstr ++ specificInstr
|
||||
}
|
||||
|
||||
// The final array is always in an x bit-mode register
|
||||
val tempAlloc = allocator.allocateTemp(X_REG)
|
||||
instructions += MoveOps(tempAlloc, X16)
|
||||
|
||||
(instructions.toList, tempAlloc)
|
||||
|
||||
case Call(_, id, args) =>
|
||||
instructions += Comment("FUNCTION CALL")
|
||||
|
||||
// Store parameter args onto stack before moving values
|
||||
instructions ++= cGen.generatePushInstructions(funcRegList, callingFunc = true)
|
||||
|
||||
// Add the needed stack space for extra arguments (args 9 onwards)
|
||||
val _addedStackSpace = ((args.size - MAX_PARAM_REGS + 1) & ~1) * EIGHT_BYTES
|
||||
val addedStackSpace = Immediate(_addedStackSpace)
|
||||
if (_addedStackSpace > 0) {
|
||||
instructions += SUB(SP, SP, addedStackSpace)
|
||||
}
|
||||
|
||||
// Efficiently retrieves and argument and then moves the content to the correct arg register
|
||||
instructions ++= args.zipWithIndex.flatMap{ case (arg, index) =>
|
||||
val instrs: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
|
||||
// Get the instructions and allocation for the argument
|
||||
val (argInstr, argAlloc) = genExpr(arg)
|
||||
|
||||
// Create an allocation for a parameter, this will either be a register or memory on the stack
|
||||
val paramAlloc = index match {
|
||||
case paramReg if paramReg < MAX_PARAM_REGS => Register(s"$index", argAlloc.bitMode)
|
||||
case paramSpill =>
|
||||
val _offset = (paramSpill - MAX_PARAM_REGS) * EIGHT_BYTES
|
||||
val offset = Immediate(_offset)
|
||||
Spill(offset, SP)
|
||||
}
|
||||
|
||||
// Add the argument's instructions and then move the argument's temp to the corresponding parameter reg
|
||||
// Restoration of registers occurs in the declaration and assignment statements
|
||||
instrs ++= argInstr
|
||||
instrs += Move(paramAlloc, argAlloc)
|
||||
allocator.freeTemp(argAlloc)
|
||||
|
||||
instrs.toList
|
||||
}
|
||||
|
||||
// Call the function
|
||||
|
||||
// The id of the function used to query the global function environment
|
||||
val funcKey = id.id
|
||||
|
||||
// Replace # with an underscore for funcLabel
|
||||
val funcLabel = id.id.replaceAll("#", "_")
|
||||
instructions += BL(funcLabel)
|
||||
|
||||
// Offload all stack manipulation for extra arguments
|
||||
if (_addedStackSpace > 0) {
|
||||
instructions += ADD(SP, SP, addedStackSpace)
|
||||
}
|
||||
|
||||
// Get lhs type
|
||||
val lhsType = gFuncEnv.lookup(funcKey).getOrElse(
|
||||
throw new RuntimeException(s"Function $funcKey not found in global function environment, couldnt get lhs")
|
||||
)._type
|
||||
|
||||
val returnReg = lhsType match {
|
||||
case (IntType | CharType | BoolType) => W0
|
||||
case _ => X0
|
||||
}
|
||||
|
||||
// Registers will be popped after declaration / assignment to avoid overwriting X0
|
||||
(instructions.toList, returnReg)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates assembly for a LHS
|
||||
*/
|
||||
def genLhs(lhs: Lhs)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator, funcRegList: List[(String, Allocated)] = List.empty):
|
||||
(List[Instruction], Allocated) = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
lhs match {
|
||||
// Differs to the expr identifier parsing as it just releases the result
|
||||
// Needed to distinguish between assignments and just using the variable
|
||||
case Ident(_, id) =>
|
||||
val idAlloc = allocator.getAllocMapping(id)
|
||||
(instructions.toList, idAlloc)
|
||||
case PairElem(_, selector, expr) =>
|
||||
val (exprInstr, exprAlloc) = genLhs(expr)
|
||||
instructions += Compare(exprAlloc, Imm0)
|
||||
instructions += BCond(EQ, _errNull)
|
||||
instructions ++= exprInstr
|
||||
(instructions.toList, exprAlloc)
|
||||
|
||||
// All other LHSs call genExpr and special cases are handled in the calling statement
|
||||
case expr: Expr => genExpr(expr)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates assembly for a LHS considering special cases for Pair and Array Elems(when storing)
|
||||
*/
|
||||
def genLhsForUsing(lhs: Lhs, parentInstrs: mutable.ListBuffer[Instruction])
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
(List[Instruction], Allocated) = {
|
||||
|
||||
lhs match {
|
||||
// In the case of an ArrayElem, you take the first n-1 indices
|
||||
// where n is the length of the indices list
|
||||
case ArrayElem(pos, id, indices) =>
|
||||
val shortenedIndices = indices.take(indices.length - 1)
|
||||
val (instr, tempAlloc) = genLhs(ArrayElem(pos, id, shortenedIndices))
|
||||
|
||||
// If the shortened query is nothing, then move the identifier's address
|
||||
// to the address that's been generated
|
||||
if (shortenedIndices.isEmpty) {
|
||||
val idAlloc = allocator.getAllocMapping(id.id)
|
||||
parentInstrs += Move(tempAlloc, idAlloc)
|
||||
}
|
||||
|
||||
// Return the instructions and allocation generated
|
||||
// instr is empty in the case of indicies being an empty list
|
||||
(instr, tempAlloc)
|
||||
|
||||
// In the case of a PairElem, you need to unfold the PairElem to get the list of selectors
|
||||
// and then iteratively load the values from the pair in reverse order
|
||||
case _: PairElem =>
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
val (selectorList, finalIdent) = unfoldPairElem(lhs)
|
||||
val (lhsInstr, lhsAlloc) = genLhs(finalIdent)
|
||||
var tempAlloc = lhsAlloc
|
||||
val reverse = selectorList.drop(1).reverse // Top case is handled in genStmt Assign
|
||||
|
||||
if (reverse.isEmpty) {
|
||||
instructions += Compare(tempAlloc, Imm0)
|
||||
instructions += BCond(EQ, _errNull)
|
||||
} // We'll need to remove duplication of code between here & genLhs
|
||||
|
||||
instructions ++= lhsInstr
|
||||
reverse.foreach { selector =>
|
||||
val tempAlloc2 = allocator.allocateTemp(X_REG)
|
||||
val offset = selector match {
|
||||
case Fst => Some(Imm0)
|
||||
case Snd => Some(Imm8)
|
||||
}
|
||||
instructions += Load(tempAlloc2, tempAlloc, offset)
|
||||
allocator.freeTemp(tempAlloc)
|
||||
instructions += Compare(tempAlloc, Imm0)
|
||||
instructions += BCond(EQ, _errNull)
|
||||
tempAlloc = tempAlloc2
|
||||
}
|
||||
(instructions.toList, tempAlloc)
|
||||
|
||||
|
||||
case other => genLhs(lhs)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Unfolds a PairElem to get the list of selectors and the final LHS
|
||||
*/
|
||||
def unfoldPairElem(lhs: Lhs): (List[PairSelector], Lhs) = lhs match {
|
||||
case PairElem(_, selector, innerLhs) =>
|
||||
val (selectors, finalLhs) = unfoldPairElem(innerLhs)
|
||||
(selector :: selectors, finalLhs)
|
||||
case _ => (Nil, lhs)
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a move of the rhs to the lhs depending on the type of LHS
|
||||
*/
|
||||
def lhsMove(lhs: Lhs, lhsAlloc: Allocated, rhsAlloc: Reg, reg8: Reg)
|
||||
(using gMainEnv: GlobalMainEnv, allocator: RegisterAllocator, labelGenerator: LabelGenerator):
|
||||
List[Instruction] = {
|
||||
|
||||
val instructions: mutable.ListBuffer[Instruction] = mutable.ListBuffer.empty
|
||||
lhs match {
|
||||
// Call on _arrayStoreN function where N is the size of the type
|
||||
case arrElem@ArrayElem(_, id, indices) =>
|
||||
|
||||
// Retrieve the final indexes instruction and allocation
|
||||
val (finalIndexInstr, finalIndexAlloc) = genExpr(indices.last)
|
||||
instructions ++= finalIndexInstr
|
||||
instructions ++= List(
|
||||
MoveOps(W17, finalIndexAlloc),
|
||||
MoveOps(reg8, rhsAlloc),
|
||||
Move(X7, lhsAlloc),
|
||||
BL(_arrStore(s"${sizeof(lhsType(arrElem))}"))
|
||||
)
|
||||
|
||||
allocator.freeTemp(finalIndexAlloc)
|
||||
|
||||
case PairElem(_, selector, expr) =>
|
||||
selector match {
|
||||
case Fst => instructions += Store(rhsAlloc, lhsAlloc)
|
||||
case Snd => instructions += Store(rhsAlloc, lhsAlloc, Some(Imm8))
|
||||
}
|
||||
|
||||
// Otherwise, move the contents of temp allocation to identifier allocation
|
||||
case _ => instructions += Move(lhsAlloc, rhsAlloc)
|
||||
}
|
||||
|
||||
instructions.toList
|
||||
}
|
||||
288
src/main/wacc/backend/registerAllocator.scala
Normal file
288
src/main/wacc/backend/registerAllocator.scala
Normal file
@ -0,0 +1,288 @@
|
||||
package wacc.backend
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Constants
|
||||
// ----------------------------------------------------
|
||||
|
||||
// Constant for the max number of variables that can be stored in a callee reg
|
||||
val MAX_CALLEE_REGS = CalleePool.available.length
|
||||
val MAX_PARAM_REGS = 8
|
||||
|
||||
// Byte constants
|
||||
val SIXTEEN_BYTES = 16
|
||||
val EIGHT_BYTES = 8
|
||||
|
||||
val X_REG = 'x'
|
||||
val W_REG = 'w'
|
||||
val FP_REG = "fp"
|
||||
val SP_REG = "sp"
|
||||
val LR_REG = "lr"
|
||||
val XZR_REG = "xzr"
|
||||
|
||||
// Registers that are used through the program (particularly for standard labels)
|
||||
val X0 = Register("0", X_REG)
|
||||
val W0 = Register("0", W_REG)
|
||||
val X1 = Register("1", X_REG)
|
||||
val W1 = Register("1", W_REG)
|
||||
val X2 = Register("2", X_REG)
|
||||
val X7 = Register("7", X_REG)
|
||||
val W7 = Register("7", W_REG)
|
||||
val X8 = Register("8", X_REG)
|
||||
val W8 = Register("8", W_REG)
|
||||
val X16 = Register("16", X_REG)
|
||||
val W16 = Register("16", W_REG)
|
||||
val X17 = Register("17", X_REG)
|
||||
val W17 = Register("17", W_REG)
|
||||
|
||||
// Special registers
|
||||
val W30 = Register("30", W_REG)
|
||||
val FP = Register(FP_REG, ' ')
|
||||
val SP = Register(SP_REG, ' ')
|
||||
val LR = Register(LR_REG, ' ')
|
||||
val XZR = Register(XZR_REG, ' ')
|
||||
|
||||
// Immediate values
|
||||
val Imm_1 = Immediate(-1)
|
||||
val Imm0 = Immediate(0)
|
||||
val Imm1 = Immediate(1)
|
||||
val Imm2 = Immediate(2)
|
||||
val Imm3 = Immediate(3)
|
||||
val Imm4 = Immediate(4)
|
||||
val Imm8 = Immediate(8)
|
||||
val Imm16 = Immediate(16)
|
||||
|
||||
|
||||
// Represents an allocated value for code generation
|
||||
sealed trait Allocated {
|
||||
def toOperand: String
|
||||
}
|
||||
|
||||
// Representds an operand value for parsing into instructions (Can only be registers or immediate values)
|
||||
sealed trait Operand {
|
||||
def toOperand: String
|
||||
}
|
||||
|
||||
// Label ops are for strings with the :lo12: prepended to the label
|
||||
case class LabelOp(strLabel: String) extends Operand {
|
||||
override def toOperand: String = s":lo12:$strLabel"
|
||||
}
|
||||
|
||||
// Represents a physical register allocation
|
||||
sealed trait Reg extends Allocated with Operand {
|
||||
|
||||
// Every register should have a register number and a bitmode
|
||||
def regNumber: String
|
||||
def bitMode: Char
|
||||
|
||||
override def toOperand: String = bitMode match {
|
||||
case space if space == ' ' => s"${regNumber}"
|
||||
case _ => s"${bitMode}${regNumber}"
|
||||
}
|
||||
}
|
||||
|
||||
object Reg {
|
||||
|
||||
// Contains the set of allowed bitmodes and register number that a reg can have
|
||||
val allowedBitModes = Set(X_REG, W_REG, ' ')
|
||||
val allowedRegNums = (0 to 31).map(_.toString()).toSet ++ Set(FP_REG, SP_REG, LR_REG, XZR_REG)
|
||||
|
||||
// Each application of a type of register must have the allowed parameters otherwise they can't be created
|
||||
def validate(regNumber: String, bitMode: Char): Unit = {
|
||||
if (!allowedRegNums.contains(regNumber) || !allowedBitModes.contains(bitMode)) {
|
||||
throw IllegalArgumentException("Cannot have a register of such number/bitmode")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A concrete physical register allocated from a specific pool
|
||||
case class PhysicalReg(regNumber: String, bitMode: Char, pool: RegisterPool) extends Reg {
|
||||
Reg.validate(regNumber, bitMode)
|
||||
}
|
||||
case class Register(regNumber: String, bitMode: Char) extends Reg {
|
||||
Reg.validate(regNumber, bitMode)
|
||||
}
|
||||
|
||||
// A class representing an immediate value
|
||||
case class Immediate(value: String) extends Operand {
|
||||
override def toOperand: String = s"#$value"
|
||||
}
|
||||
object Immediate {
|
||||
def apply(value: String): Immediate = new Immediate(value)
|
||||
def apply(value: BigInt): Immediate = new Immediate(s"$value")
|
||||
}
|
||||
|
||||
// A spill slot allocation on the stack
|
||||
// (Can change the basePointer if we want something other than FP)
|
||||
case class Spill(offset: Immediate, basePointer: Reg = FP) extends Allocated {
|
||||
override def toOperand: String = {
|
||||
s"[${basePointer.toOperand}, ${offset.toOperand}]"
|
||||
}
|
||||
}
|
||||
|
||||
// A trait to tag register pools
|
||||
sealed trait RegisterPool {
|
||||
def available: List[String]
|
||||
def name: String
|
||||
}
|
||||
|
||||
// The callee-saved register pool (register numbers as strings)
|
||||
case object CalleePool extends RegisterPool {
|
||||
override val available: List[String] =
|
||||
List("19", "20", "21", "22", "23", "24", "25", "26", "27", "28")
|
||||
override val name: String = "callee"
|
||||
}
|
||||
|
||||
// The temporary register pool (register numbers as strings)
|
||||
case object TempPool extends RegisterPool {
|
||||
override val available: List[String] =
|
||||
List("8", "9", "10", "11", "12", "13", "14", "15")
|
||||
override val name: String = "temp"
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Allocates either Physical registers (x or w bit mode registers)
|
||||
* It maintains two independent pools:
|
||||
* one for callee registers and one for temporary registers.
|
||||
* Allocation functions take a parameter (W_REG or X_REG) to determine the correct register type
|
||||
*/
|
||||
class RegisterAllocator {
|
||||
private var calleeFreeRegNums: List[String] = CalleePool.available
|
||||
private var tempFreeRegNums: List[String] = TempPool.available
|
||||
|
||||
// Map from an identifier to its allocated value
|
||||
val identifierMap = mutable.Map[String, Allocated]()
|
||||
|
||||
// A counter to allocate unique spill slots (each 8 bytes)
|
||||
// Default is the negative of the max num of callee registers (As we do things in reference to fp)
|
||||
private var spillCounter: Int = -1 * MAX_CALLEE_REGS
|
||||
|
||||
private def allocateSpillSlot(): Immediate = {
|
||||
spillCounter -= 1
|
||||
Immediate(spillCounter * EIGHT_BYTES)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds an allocation mapping for an identifier
|
||||
* The mapping preserves the exact register type (including its bit mode)
|
||||
*/
|
||||
def addAllocMapping(ident: String, reg: Allocated): Unit = {
|
||||
identifierMap += (ident -> reg)
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the allocated value for an identifier
|
||||
*/
|
||||
def getAllocMapping(ident: String): Allocated =
|
||||
identifierMap.getOrElse(ident, throw new RuntimeException(s"Identifier $ident not allocated"))
|
||||
|
||||
/**
|
||||
* Gets the number of variables used in the program //TODO: May have to change this logic a bit later
|
||||
*/
|
||||
def getNumVariables(): BigInt = identifierMap.size
|
||||
|
||||
/**
|
||||
* Allocates a temporary register with the given bit mode (W_REG or X_REG)
|
||||
* Assumption: There is always an available temp register
|
||||
*/
|
||||
def allocateTemp(bitMode: Char): PhysicalReg =
|
||||
tempFreeRegNums match {
|
||||
case regNum :: rest =>
|
||||
tempFreeRegNums = rest
|
||||
PhysicalReg(regNum, bitMode, TempPool)
|
||||
// Just a fall back which shouldn't happen according to our policy
|
||||
case Nil => PhysicalReg("Temp Pool should be available", ' ', TempPool)
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocates a callee register with the given bit mode (W_REG or X_REG)
|
||||
* If no registers remain, a spill slot is allocated
|
||||
*/
|
||||
def allocateCallee(bitMode: Char): Allocated =
|
||||
calleeFreeRegNums match {
|
||||
case regNum :: rest =>
|
||||
calleeFreeRegNums = rest
|
||||
PhysicalReg(regNum, bitMode, CalleePool)
|
||||
case Nil =>
|
||||
Spill(allocateSpillSlot())
|
||||
}
|
||||
|
||||
/**
|
||||
* Frees a previously allocated temp register
|
||||
* Callee registers cannot be freed
|
||||
*/
|
||||
def freeTemp(alloc: Allocated): Unit = alloc match {
|
||||
case reg: PhysicalReg => reg.pool match {
|
||||
case TempPool =>
|
||||
if (!tempFreeRegNums.contains(reg.regNumber))
|
||||
tempFreeRegNums = reg.regNumber :: tempFreeRegNums
|
||||
|
||||
// Can only free a temp reg
|
||||
case CalleePool => ()
|
||||
}
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
/**
|
||||
* Restores a temp if the pool isn't full
|
||||
*/
|
||||
def restore(): Unit = {
|
||||
if (tempFreeRegNums.length < TempPool.available.length) {
|
||||
// This is possible due to policy of Temp pool always being available
|
||||
val restoredRegNum = tempFreeRegNums.headOption.getOrElse("16").toInt - 1
|
||||
restoredRegNum match {
|
||||
case _ if restoredRegNum >= 8 => tempFreeRegNums = s"$restoredRegNum" :: tempFreeRegNums
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Frees a previously allocated callee
|
||||
* For physical registers, the register number is returned to the corresponding pool
|
||||
* Spill slots are not reclaimed
|
||||
* Temp registers cannot be freed
|
||||
*/
|
||||
def freeCallee(alloc: Allocated): Unit = alloc match {
|
||||
case reg: PhysicalReg =>
|
||||
reg.pool match {
|
||||
// Can only free a callee reg
|
||||
case TempPool => ()
|
||||
case CalleePool =>
|
||||
if (!calleeFreeRegNums.contains(reg.regNumber))
|
||||
calleeFreeRegNums = reg.regNumber :: calleeFreeRegNums
|
||||
}
|
||||
case _ => // Other allocations are not reclaimed
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the allocator: returns all registers to their pools
|
||||
* resets the spill counter, and clears identifier mappings
|
||||
*/
|
||||
def reset(): Unit = {
|
||||
calleeFreeRegNums = CalleePool.available
|
||||
tempFreeRegNums = TempPool.available
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves all of the register information from before to later be restored
|
||||
*/
|
||||
def saveState(): (List[String], List[String], Int, mutable.Map[String, Allocated]) = {
|
||||
(calleeFreeRegNums, tempFreeRegNums, spillCounter, identifierMap.clone())
|
||||
}
|
||||
|
||||
/**
|
||||
* Restores the state of the allocator
|
||||
*/
|
||||
def restoreState(state: (List[String], List[String], Int, mutable.Map[String, Allocated])): Unit = {
|
||||
val (savedCalleeFreeRegNums, savedTempFreeRegNums, savedSpillCounter, savedIdentifierMap) = state
|
||||
calleeFreeRegNums = savedCalleeFreeRegNums
|
||||
tempFreeRegNums = savedTempFreeRegNums
|
||||
spillCounter = savedSpillCounter
|
||||
identifierMap.clear()
|
||||
identifierMap ++= savedIdentifierMap
|
||||
}
|
||||
}
|
||||
75
src/main/wacc/backend/typeAnalysis.scala
Normal file
75
src/main/wacc/backend/typeAnalysis.scala
Normal file
@ -0,0 +1,75 @@
|
||||
package wacc.backend
|
||||
|
||||
import wacc.frontend.syntax.ast._
|
||||
import wacc.frontend.semantic.environment._
|
||||
|
||||
val IGNORE = (-1, -1)
|
||||
|
||||
/**
|
||||
* Allocates the size of a type
|
||||
*/
|
||||
def sizeof(_type: Type) = _type match {
|
||||
case IntType => 4
|
||||
case (CharType | BoolType) => 1
|
||||
case _ => 8
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the type of an expression
|
||||
*/
|
||||
def exprType(expr: Expr)(using gMainEnv: EnvTrait): Type =
|
||||
expr match {
|
||||
case _: (IntLiter | Neg | Len | Ord | Mul | Div | Mod | Add | Sub) => IntType
|
||||
case _: (BoolLiter | Not | Gt | Geq | Lt | Leq | Eq | Neq | And | Or) => BoolType
|
||||
case _: (CharLiter | Chr) => CharType
|
||||
case _: StrLiter => StrType
|
||||
case _: PairLiter => PairPlaceholder // The null type is a PairPlaceholder
|
||||
case Ident(_, id) =>
|
||||
gMainEnv.lookupType(id).getOrElse(AnyType)
|
||||
case ArrayElem(_, id, indices) =>
|
||||
// Peel off array layers from the type of the id
|
||||
val arrType = exprType(id)
|
||||
val finalType = {
|
||||
indices.foldLeft(arrType) {
|
||||
case (ArrayType(inner), _) => inner
|
||||
case (_, _) => AnyType
|
||||
}
|
||||
}
|
||||
finalType
|
||||
case pairElem: PairElem => lhsType(pairElem)
|
||||
case IfExpr(pos, cond, thenBranch, elseBranch) => exprType(thenBranch)
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the type of a RHS
|
||||
*/
|
||||
def rhsType(rhs: Rhs)
|
||||
(using gMainEnv: EnvTrait, gFuncEnv: EnvTrait): Type = rhs match {
|
||||
case expr: Expr => exprType(expr)(using gMainEnv)
|
||||
case ArrayLiter(_, elems) =>
|
||||
// An array may have at least 1 element where its type is found through the call to exprType
|
||||
ArrayType(exprType(elems.headOption.getOrElse(PairLiter(IGNORE)))(using gMainEnv))
|
||||
case NewPair(_, fst, snd) => (exprType(fst)(using gMainEnv), exprType(snd)(using gMainEnv)) match {
|
||||
case (fType, sType) => PairType(fType, sType)
|
||||
}
|
||||
case Call(_, ident, args) =>
|
||||
gFuncEnv.lookup(ident.id).get._type
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the type of a LHS
|
||||
*/
|
||||
def lhsType(lhs: Lhs)
|
||||
(using gMainEnv: EnvTrait): Type = lhs match {
|
||||
case identifier: (Ident | ArrayElem) => exprType(identifier)
|
||||
case PairElem(_, selector, subLhs) =>
|
||||
val subType = lhsType(subLhs)
|
||||
(selector, subType) match {
|
||||
case (Fst, PairType(PairPlaceholder, _)) => PairPlaceholder
|
||||
case (Snd, PairType(_, PairPlaceholder)) => PairPlaceholder
|
||||
case (Fst, PairType(t:Type, _)) => t
|
||||
case (Snd, PairType(_, t:Type)) => t
|
||||
// This is for any null case, so would be named an AnyType may have TODO a change on this
|
||||
case _ => AnyType
|
||||
}
|
||||
}
|
||||
267
src/main/wacc/extension/peephole.scala
Normal file
267
src/main/wacc/extension/peephole.scala
Normal file
@ -0,0 +1,267 @@
|
||||
package wacc.extension
|
||||
|
||||
import wacc.backend._
|
||||
import scala.collection.mutable
|
||||
|
||||
class Peephole {
|
||||
def optimize(instructions: List[Instruction]): List[Instruction] = {
|
||||
Function.chain(List(removeRedundantLoadandStores, removeRedundantMoves))(instructions)
|
||||
}
|
||||
|
||||
private def removeRedundantLoadandStores(instructions: List[Instruction]): List[Instruction] = {
|
||||
val substitutions = mutable.ListBuffer[(Allocated | Operand, Reg, Int)]()
|
||||
val removalList = mutable.ListBuffer[Int]()
|
||||
|
||||
instructions.zipWithIndex.foreach({
|
||||
// Keep track of stored values, and index of instruction, and remove in situations where the register may be changed
|
||||
case (STP(dest, src, base, _), index) =>
|
||||
substitutions += ((dest, base, index))
|
||||
substitutions += ((src, base, index))
|
||||
|
||||
case (STR(dest, base, _, _), index) => substitutions += ((dest, base, index))
|
||||
case (STRB(dest, base, _), index) => substitutions += ((dest, base, index))
|
||||
case (STUR(dest, base), index) => substitutions += ((dest, base, index))
|
||||
|
||||
// Given a load, check if the register is still stored in the substitutions, if so then the load-store is redundant
|
||||
case (LDP(dest, src, base), index) =>
|
||||
substitutions.find { case (operand, reg, _) => operand == dest && reg == base } match {
|
||||
case Some((_, _, fstIndex)) =>
|
||||
substitutions.find { case (operand, reg, _) => operand == src && reg == base } match {
|
||||
case Some((_, _, sndIndex)) =>
|
||||
removalList += sndIndex; removalList += index
|
||||
case None =>
|
||||
}
|
||||
case None =>
|
||||
}
|
||||
|
||||
case (LDR(dest, base, _, _), index) =>
|
||||
substitutions.find { case (operand, reg, _) => operand == dest && reg == base } match {
|
||||
case Some((_, _, fstIndex)) =>
|
||||
removalList += fstIndex; removalList += index
|
||||
case None =>
|
||||
}
|
||||
|
||||
case (LDRB(dest, base, _), index) =>
|
||||
substitutions.find { case (operand, reg, _) => operand == dest && reg == base } match {
|
||||
case Some((_, _, fstIndex)) =>
|
||||
removalList += fstIndex; removalList += index
|
||||
case None =>
|
||||
}
|
||||
|
||||
case (LDUR(dest, base), index) =>
|
||||
substitutions.find { case (operand, reg, _) => operand == dest && reg == base } match {
|
||||
case Some((_, _, fstIndex)) =>
|
||||
removalList += fstIndex; removalList += index
|
||||
case None =>
|
||||
}
|
||||
// Consider all cases where the register is modified, and remove the stored value
|
||||
case (MoveOps(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (Move(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (ADD(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (SUB(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (SMULL(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (SDIV(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (ADDS(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (SUBS(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (MSUB(dest, _, _, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (NEGS(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (EOR(dest, _, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (ADRP(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (ADR(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (CSET(dest, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
case (CSEL(dest, _, _, _), _) => substitutions.filterInPlace(_._1 != dest)
|
||||
|
||||
// Clear the substitutions on branches given any register may be manipulated in the callee
|
||||
case (B(_), _) => substitutions.clear()
|
||||
case (BL(_, _), _) => substitutions.clear()
|
||||
case (BCond(_, _, _), _) => substitutions.clear()
|
||||
case (CBZ(_, _), _) => substitutions.clear()
|
||||
|
||||
case (Label(_), _) => substitutions.clear()
|
||||
|
||||
case _ =>
|
||||
})
|
||||
|
||||
instructions.zipWithIndex.collect {
|
||||
case (instr, index) if !removalList.contains(index) =>
|
||||
instr
|
||||
}
|
||||
}
|
||||
|
||||
// When encountering a move, add the dest and src to a substitutions. if that dest is used in another instruction then substitute the dest with src
|
||||
// given the original src remains unchanged. We then do a second pass that removes the redundant moves.
|
||||
private def removeRedundantMoves(instructions: List[Instruction]): List[Instruction] = {
|
||||
given allocator: RegisterAllocator = new RegisterAllocator()
|
||||
val substitutions = mutable.ListBuffer[((Allocated | Operand, Allocated | Operand), Int)]() // List of (dest, src) to index
|
||||
val substitutedLines = mutable.ListBuffer[((Allocated | Operand, Allocated | Operand), (Int, Int))]() // List of substitutions to their initial and subbed line
|
||||
val removalList = mutable.ListBuffer[Int]() // List of indexes to remove
|
||||
|
||||
// Up until the dest/src in original move has been changed, you can substitute the dest for the src in subsequent instructions
|
||||
def addSubstitution(dest: Allocated, src: List[Allocated | Operand], index: Int): Unit = {
|
||||
src.foreach { source =>
|
||||
substitutions.find(_._1._1 == source) match {
|
||||
case Some(((srcDest, sub), subIndex)) =>
|
||||
substitutedLines += (((srcDest, sub), (index, subIndex)))
|
||||
removalList += subIndex
|
||||
case None =>
|
||||
}
|
||||
}
|
||||
// removes the dest from the substitutions list as both destination and source if it is changed
|
||||
substitutions.filterInPlace(_._1._1 != dest)
|
||||
substitutions.filterInPlace(_._1._2 != dest)
|
||||
}
|
||||
|
||||
// Given a dest, src, and index, substitute the dest with src in the instructions
|
||||
// If operandBool is true, then other types of operands (other than Reg) are allowed in operands 2 onwards
|
||||
def useSubstitution(dest: Reg, src: List[Operand], index: Int, operandBool: Boolean): List[Operand] = {
|
||||
val indexes = substitutedLines.filter(_._2._1 == index).map(_._2)
|
||||
val subs = substitutedLines.filter(_._2._1 == index).map(_._1)
|
||||
val mutableSrc = mutable.ListBuffer(src*)
|
||||
subs.foreach {
|
||||
case (subDest, sub: Reg) =>
|
||||
mutableSrc.zipWithIndex.foreach { case (source, idx) =>
|
||||
if (subDest == source) mutableSrc(idx) = sub
|
||||
}
|
||||
case (subDest, sub: Operand) if operandBool =>
|
||||
mutableSrc.zipWithIndex.foreach { case (source, idx) =>
|
||||
if (idx != 0 && subDest == source) mutableSrc(idx) = sub
|
||||
}
|
||||
case _ =>
|
||||
|
||||
}
|
||||
mutableSrc.zip(src).zip(indexes).foreach { case ((a, b), i) =>
|
||||
if (a == b) {
|
||||
substitutedLines.filterInPlace(_._2 != i)
|
||||
removalList.filterInPlace(_ != i._2)
|
||||
}
|
||||
}
|
||||
mutableSrc.toList
|
||||
}
|
||||
|
||||
// Initial pass, where substitutions are added and instructions are checked to see whether they can be substituted
|
||||
instructions.zipWithIndex.foreach {
|
||||
case (MoveOps(dest, src), index) if dest == src =>
|
||||
removalList += index
|
||||
case (Move(dest, src), index) if dest == src =>
|
||||
removalList += index
|
||||
|
||||
case (MoveOps(dest, src), index) =>
|
||||
src match {
|
||||
// Immediate needs to be explicitly dealt with to avoid overflow substitution
|
||||
case Immediate(value) =>
|
||||
val valInt = value.toInt
|
||||
if (valInt < 0 || valInt >= 0xFFFF) {
|
||||
substitutions.filterInPlace(_._1._1 != dest)
|
||||
substitutions.filterInPlace(_._1._2 != dest)
|
||||
} else {
|
||||
addSubstitution(dest, List(src), index)
|
||||
substitutions += (((dest, src), index))
|
||||
}
|
||||
case _ =>
|
||||
addSubstitution(dest, List(src), index)
|
||||
substitutions += (((dest, src), index))
|
||||
}
|
||||
|
||||
case (Move(dest, src), index) => addSubstitution(dest, List(src), index)
|
||||
|
||||
case (MSUB(dest, src1, src2, src3), index) => substitutions.clear()
|
||||
case (CSEL(dest, src1, src2, src3), index) => substitutions.clear()
|
||||
|
||||
case (ADD(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
|
||||
case (ADDS(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
|
||||
case (SUB(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
|
||||
case (SUBS(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
|
||||
case (SMULL(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
|
||||
case (SDIV(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
|
||||
case (EOR(dest, src1, src2), index) => addSubstitution(dest, List(src1, src2), index)
|
||||
|
||||
case (NEGS(dest, src), index) => addSubstitution(dest, List(src), index)
|
||||
|
||||
// Although these instructions do not use operand registers, addSubsitutions is still used to clear substitutions based on destination
|
||||
case (ADRP(dest, src), index) => addSubstitution(dest, List(), index)
|
||||
case (ADR(dest, src), index) => addSubstitution(dest, List(), index)
|
||||
case (CSET(dest, src), index) => addSubstitution(dest, List(), index)
|
||||
|
||||
case (Compare(_, _, _), _) =>
|
||||
case (CMP(_, _, _), _) =>
|
||||
case (TST(_, _), _) =>
|
||||
case (Comment(_, _), _) =>
|
||||
case (Directive(_, _), _) =>
|
||||
case (EOF(), _) =>
|
||||
case (RET(), _) =>
|
||||
|
||||
// Clear substitutions on any instruction that hasnt been explicitly handled to maintain consistency.
|
||||
case x => substitutions.clear()
|
||||
}
|
||||
|
||||
// Second pass, where instructions are substituted given they were added in the first pass
|
||||
val substituted = instructions.zipWithIndex.collect {
|
||||
case (MoveOps(dest, src), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val indexes = substitutedLines.find(_._2._1 == index).get._2
|
||||
val sub = substitutedLines.find(_._2._1 == index).get._1._2
|
||||
sub match {
|
||||
case x : Operand => MoveOps(dest, x)
|
||||
case _ =>
|
||||
substitutedLines.filterInPlace(_._2 != indexes)
|
||||
removalList.filterInPlace(_ != indexes._2)
|
||||
MoveOps(dest, src)
|
||||
}
|
||||
case (Move(dest, src), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val indexes = substitutedLines.find(_._2._1 == index).get._2
|
||||
val sub = substitutedLines.find(_._2._1 == index).get._1._2
|
||||
sub match {
|
||||
case x : Allocated => Move(dest, x)
|
||||
case x : Operand =>
|
||||
dest match {
|
||||
case y : Reg => MoveOps(y, x)
|
||||
case _ =>
|
||||
substitutedLines.filterInPlace(_._2 != indexes)
|
||||
removalList.filterInPlace(_ != indexes._2)
|
||||
Move(dest, src)
|
||||
}
|
||||
}
|
||||
|
||||
// We call useSubstitution for generic cases on all operands, and reconstruct instruction with substituted operands.
|
||||
case (ADD(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val srcs = useSubstitution(dest, List(src1, src2), index, true)
|
||||
ADD(dest, srcs(0), srcs(1))
|
||||
|
||||
case (ADDS(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val srcs = useSubstitution(dest, List(src1, src2), index, true)
|
||||
ADDS(dest, srcs(0), srcs(1))
|
||||
|
||||
case (SUB(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val srcs = useSubstitution(dest, List(src1, src2), index, true)
|
||||
SUB(dest, srcs(0), srcs(1))
|
||||
|
||||
case (SUBS(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val srcs = useSubstitution(dest, List(src1, src2), index, true)
|
||||
SUBS(dest, srcs(0), srcs(1))
|
||||
|
||||
case (SMULL(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val srcs = useSubstitution(dest, List(src1, src2), index, false)
|
||||
SMULL(dest, srcs(0), srcs(1))
|
||||
|
||||
case (SDIV(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val srcs = useSubstitution(dest, List(src1, src2), index, false)
|
||||
SDIV(dest, srcs(0), srcs(1))
|
||||
|
||||
case (EOR(dest, src1, src2), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val srcs = useSubstitution(dest, List(src1, src2), index, true)
|
||||
EOR(dest, srcs(0), srcs(1))
|
||||
|
||||
case (NEGS(dest, src), index) if substitutedLines.exists(_._2._1 == index) =>
|
||||
val srcs = useSubstitution(dest, List(src), index, false)
|
||||
NEGS(dest, srcs(0))
|
||||
|
||||
case (instr, index) =>
|
||||
instr
|
||||
}
|
||||
|
||||
// Finally, return the instructions with the redundant moves removed.
|
||||
substituted.zipWithIndex.collect {
|
||||
case (instr, index) if !removalList.contains(index) =>
|
||||
instr
|
||||
}
|
||||
}
|
||||
}
|
||||
185
src/main/wacc/frontend/semantic/environment.scala
Normal file
185
src/main/wacc/frontend/semantic/environment.scala
Normal file
@ -0,0 +1,185 @@
|
||||
package wacc.frontend.semantic
|
||||
|
||||
import wacc.frontend.syntax.ast._
|
||||
import wacc.backend.exprType
|
||||
import scala.collection.mutable
|
||||
|
||||
object environment {
|
||||
trait EnvTrait {
|
||||
def lookup(name: String): Option[Symbol]
|
||||
def lookupType(id: String): Option[Type]
|
||||
}
|
||||
|
||||
// An environment for variable (and parameter) declarations
|
||||
class Env(val parent: Option[Env]) extends EnvTrait {
|
||||
val vars: mutable.Map[String, VarSymbol] = mutable.Map.empty
|
||||
|
||||
override def lookup(name: String): Option[VarSymbol] = vars.get(name) match {
|
||||
case None => parent.flatMap(_.lookup(name))
|
||||
case some => some
|
||||
}
|
||||
|
||||
override def lookupType(id: String): Option[Type] = {
|
||||
|
||||
// For environments, we check on their raw input, not on the hashed value
|
||||
val idKey = id.replaceAll("#\\d+$", "")
|
||||
lookup(idKey) match {
|
||||
case None => None
|
||||
case Some(VarSymbol(_, _, _type, _)) => Some(_type)
|
||||
}
|
||||
}
|
||||
|
||||
def add(sym: VarSymbol)(using errors: mutable.ListBuffer[SemanticError]): Unit = {
|
||||
if (!vars.contains(sym.id)) {
|
||||
vars(sym.id) = sym
|
||||
} else {
|
||||
// Illegal declaration
|
||||
errors += RedeclarationError(sym.pos, sym.id, inFunction = false)(using this)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Global environment for functions
|
||||
class GlobalFuncEnv extends EnvTrait {
|
||||
|
||||
// A map from the base identifier (not renamed) to a set of Tuples of
|
||||
// n-arity of Types
|
||||
val funcPoolMap: mutable.Map[String, mutable.ListBuffer[(List[Type], String)]] = mutable.Map.empty
|
||||
|
||||
// A map from the unique identifier to its func symbol
|
||||
val funcs: mutable.Map[String, FuncSymbol] = mutable.Map.empty
|
||||
|
||||
// Looks up the function and may return its symbol
|
||||
override def lookup(id: String): Option[FuncSymbol] = funcs.get(id)
|
||||
|
||||
override def lookupType(id: String): Option[Type] = funcs.get(id) match {
|
||||
case None => None
|
||||
case Some(FuncSymbol(_, _, _type, _, _)) => Some(_type)
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs lookup for accessing the correct unique symbol for calls
|
||||
* Does a rudimentary type analysis to match to the correct function (we don't actually check if the typing is correct)
|
||||
*/
|
||||
def callLookup(id: String, args: List[Expr], pos: (Int, Int), localEnv: Env)
|
||||
(using errors: mutable.ListBuffer[SemanticError]): Option[FuncSymbol] = {
|
||||
|
||||
// Create a list of n-arity of Types where n is the length of args
|
||||
val argSignature = args.map(exprType(_)(using localEnv))
|
||||
|
||||
// Get the number of matched functions and the last matched uniqueIds
|
||||
val (matchedFunctions, uniqueId) = funcPoolMap.get(id) match {
|
||||
case None => (0, "")
|
||||
case Some(funcSignatures) => countMatches(funcSignatures.toList, argSignature)
|
||||
}
|
||||
|
||||
// Match on the number of functions matched
|
||||
// (Greater than 1 means ambiguity)
|
||||
matchedFunctions match {
|
||||
case 0 =>
|
||||
// Undefined Function
|
||||
errors += UndefinedFuncError(pos, id)
|
||||
None
|
||||
|
||||
// Finally, get the uniqueId
|
||||
case 1 => funcs.get(uniqueId)
|
||||
|
||||
case _ =>
|
||||
// Ambiguous Function call
|
||||
errors += AmbiguousFuncCallError(pos, id)
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// Potentially adds a FuncSymbol to the funcs map but first checks if its parameters exist in
|
||||
// the pool map. If they do, then we add a redeclaration error
|
||||
def add(funcSymbol: FuncSymbol)(using errors: mutable.ListBuffer[SemanticError]): Unit = {
|
||||
|
||||
// Extract needed parts of funcSymbol
|
||||
val FuncSymbol(id, uniqueId, _, paramSymbs, pos) = funcSymbol
|
||||
|
||||
// Create a list of n-arity of Types where n is the length of parameters
|
||||
val paramTypes = getParamTypes(paramSymbs)
|
||||
|
||||
// Find the list of function signatures from the funcPoolMap
|
||||
val funcSignatures = funcPoolMap.getOrElse(id, mutable.ListBuffer.empty)
|
||||
|
||||
|
||||
val (matchedFunctions, _) = countMatches(funcSignatures.toList, paramTypes)
|
||||
|
||||
matchedFunctions match {
|
||||
|
||||
// We now check the uniqueness of paramTypes
|
||||
case 0 =>
|
||||
|
||||
// Create the tuple of paramTypes with the unique id of the function
|
||||
val typeWithId = (paramTypes, uniqueId)
|
||||
funcPoolMap.getOrElseUpdate(id, mutable.ListBuffer.empty) += typeWithId
|
||||
funcs += (uniqueId -> funcSymbol)
|
||||
|
||||
// Illegal declaration
|
||||
case _ =>
|
||||
errors += RedeclarationError(pos, id, inFunction = true, paramTypes)(using this)
|
||||
funcs += (id -> funcSymbol)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the number of matches that an argSignature makes with the funcSignatures that are known to the program thus far
|
||||
* Returns the number of matches and the last seen unique id
|
||||
*/
|
||||
def countMatches(funcSignatures: List[(List[Type], String)], argSignature: List[Type]): (Int, String) = {
|
||||
|
||||
// Get the candidate functions with the same arity as the called function
|
||||
val matchCandidates = funcSignatures.filter(_._1.size == argSignature.size)
|
||||
var matchedFunctions = 0
|
||||
var uniqueId = ""
|
||||
|
||||
matchCandidates.foreach { candidate =>
|
||||
|
||||
// For each candidate, we check if the signatures match
|
||||
val (paramSignature, uId) = candidate
|
||||
val isMatch = argSignature.zip(paramSignature).foldLeft(true) { (acc, zippedSignature) =>
|
||||
|
||||
// Extract the arg and paramter type, then see if the argument type casts to the parameter
|
||||
val (argType, paramType) = zippedSignature
|
||||
(argType typecast Some(paramType)) && acc
|
||||
}
|
||||
|
||||
// We update the functions matched and unique id
|
||||
if (isMatch) {
|
||||
matchedFunctions += 1
|
||||
uniqueId = uId
|
||||
}
|
||||
}
|
||||
(matchedFunctions, uniqueId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a list of the types that the parameter has
|
||||
*/
|
||||
private def getParamTypes(paramSymbs: List[VarSymbol]): List[Type] = paramSymbs.map { paramSym =>
|
||||
|
||||
// Extract the paramType from paramSym
|
||||
val VarSymbol(_, _, paramType, _) = paramSym
|
||||
|
||||
// Output the paramType
|
||||
paramType
|
||||
}
|
||||
}
|
||||
|
||||
// Global environment for main body
|
||||
class GlobalMainEnv extends EnvTrait {
|
||||
val vars: mutable.Map[String, VarSymbol] = mutable.Map.empty
|
||||
|
||||
def lookup(id: String): Option[Symbol] = vars.get(id)
|
||||
override def lookupType(id: String): Option[Type] = vars.get(id) match {
|
||||
case Some(VarSymbol(_, _, _type, _)) => Some(_type)
|
||||
case None => None
|
||||
}
|
||||
|
||||
def add(_var: VarSymbol)(using errors: mutable.ListBuffer[SemanticError] = mutable.ListBuffer.empty): Unit = {
|
||||
vars += (_var.uniqueName -> _var)
|
||||
}
|
||||
}
|
||||
}
|
||||
274
src/main/wacc/frontend/semantic/renamer.scala
Normal file
274
src/main/wacc/frontend/semantic/renamer.scala
Normal file
@ -0,0 +1,274 @@
|
||||
package wacc.frontend.semantic
|
||||
|
||||
import scala.collection.mutable
|
||||
import wacc.frontend.semantic.environment._
|
||||
import wacc.frontend.syntax.ast._
|
||||
|
||||
class Renamer {
|
||||
|
||||
// A counter to generate unique names
|
||||
private var counter: Int = 0
|
||||
|
||||
private def generateUnique(name: String): String = {
|
||||
counter += 1
|
||||
s"$name#$counter"
|
||||
}
|
||||
|
||||
// Renames the entire program, returning a new WProgram with uniquely renamed identifiers
|
||||
// It first updates the global function environment, then renames the main block and each function
|
||||
def renameProgram(program: WProgram):
|
||||
(WProgram, mutable.ListBuffer[SemanticError], GlobalFuncEnv, GlobalMainEnv) = {
|
||||
|
||||
// Mutable list of errors found while traversing AST for renaming
|
||||
// (Would be populated with scope errors for now)
|
||||
given errors: mutable.ListBuffer[SemanticError] = mutable.ListBuffer.empty
|
||||
|
||||
// Global Environments for both Functions and the Main body
|
||||
given gFuncEnv: GlobalFuncEnv = new GlobalFuncEnv
|
||||
given gMainEnv: GlobalMainEnv = new GlobalMainEnv
|
||||
|
||||
// Create a new environment for starting a block block
|
||||
{
|
||||
given localEnv: Env = new Env(None)
|
||||
given inFunction: Boolean = false
|
||||
|
||||
// Rename each function.
|
||||
val newFuncs = renameFuncs(program.funcs)
|
||||
|
||||
// Rename main block statements.
|
||||
val newStats = program.stats.map(renameStmt)
|
||||
|
||||
// We'll keep the errors found so far and a new AST with renamed nodes
|
||||
(WProgram(program.pos, program.imports, newFuncs, newStats), errors, gFuncEnv, gMainEnv)
|
||||
}
|
||||
}
|
||||
|
||||
/// Renames all function definitions
|
||||
def renameFuncs(funcs: List[Func])
|
||||
(using errors: mutable.ListBuffer[SemanticError],
|
||||
gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): List[Func] = {
|
||||
|
||||
// Contains a funcSymbol with a list of its parameters, body, and associated environment
|
||||
val fGroups: List[(FuncSymbol, List[Param], List[Stmt], Env)] = funcs.map { f =>
|
||||
|
||||
// Gen a unique name for the function
|
||||
val uniqueFuncName = generateUnique(f.id.id)
|
||||
|
||||
// Get the param symbol and param lists at the same time to feed into
|
||||
// making a FuncSymbol and Func
|
||||
val (paramSymbols, newParams) = f.args.map { p =>
|
||||
val uniqueParamName = generateUnique(p.id.id)
|
||||
val newIdent = Ident(p.id.pos, uniqueParamName)
|
||||
(VarSymbol(p.id.id, uniqueParamName, p.t, p.pos), Param(p.pos, p.t, newIdent))
|
||||
}.unzip
|
||||
|
||||
// Create a funcSym using the paramSymbols and uniqueFuncName created and then add to
|
||||
// GlobalFuncEnv
|
||||
val funcSym = FuncSymbol(f.id.id, uniqueFuncName, f.t, paramSymbols, f.id.pos)
|
||||
gFuncEnv.add(funcSym)
|
||||
|
||||
// Create an environment to add the paramSymbols
|
||||
val funcEnv = new Env(None)
|
||||
paramSymbols.map(funcEnv.add)
|
||||
|
||||
(funcSym, newParams, f.body, funcEnv)
|
||||
}
|
||||
|
||||
val newFuncs = fGroups.map { fGroup =>
|
||||
val (funcSym, newParams, body, funcEnv) = fGroup
|
||||
|
||||
given localEnv: Env = new Env(Some(funcEnv))
|
||||
given inFunction: Boolean = true
|
||||
|
||||
// Create the new identity and body
|
||||
val newIdent = Ident(funcSym.pos, funcSym.uniqueName)
|
||||
val newBody = body.map((stmt: Stmt) => renameStmt(stmt))
|
||||
|
||||
Func(funcSym.pos, funcSym._type, newIdent, newParams, newBody)
|
||||
}
|
||||
newFuncs
|
||||
}
|
||||
|
||||
// Rename the elements within statements while checking the validity of identifiers
|
||||
def renameStmt(stmt: Stmt)
|
||||
(using localEnv: Env, inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
|
||||
gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): Stmt = {
|
||||
|
||||
stmt match {
|
||||
case Skip(pos) => Skip(pos)
|
||||
case Assign(pos, lhs, rhs) => Assign(pos, renameLhs(lhs), renameRhs(rhs))
|
||||
case Read(pos, lhs) => Read(pos, renameLhs(lhs))
|
||||
case Free(pos, expr) => Free(pos, renameExpr(expr))
|
||||
case Return(pos, expr) => Return(pos, renameExpr(expr))
|
||||
case Exit(pos, expr) => Exit(pos, renameExpr(expr))
|
||||
case Print(pos, expr) => Print(pos, renameExpr(expr))
|
||||
case Println(pos, expr) => Println(pos, renameExpr(expr))
|
||||
|
||||
// Add a new symbol to the environment on declaration
|
||||
case Declare(pos, t, ident, rhs) => {
|
||||
val newRhs = renameRhs(rhs)
|
||||
val uniqueName = generateUnique(ident.id)
|
||||
val newIdent = Ident(ident.pos, uniqueName)
|
||||
val varSym = VarSymbol(ident.id, uniqueName, t, ident.pos)
|
||||
|
||||
// Add the new symbol to the local and global variable environment
|
||||
localEnv.add(varSym)
|
||||
gMainEnv.add(varSym)
|
||||
Declare(pos, t, newIdent, newRhs)
|
||||
}
|
||||
|
||||
case If(pos, cond, thenCase, elseCase) => {
|
||||
val newCond = renameExpr(cond)
|
||||
// Create new environments for each branch.
|
||||
val thenEnv = new Env(Some(localEnv))
|
||||
val newThen = {
|
||||
given localEnv: Env = thenEnv
|
||||
thenCase.map(renameStmt)
|
||||
}
|
||||
val elseEnv = new Env(Some(localEnv))
|
||||
val newElse = {
|
||||
given localEnv: Env = elseEnv
|
||||
elseCase.map(renameStmt)
|
||||
}
|
||||
If(pos, newCond, newThen, newElse)
|
||||
}
|
||||
|
||||
case While(pos, cond, body) => {
|
||||
val newCond = renameExpr(cond)
|
||||
val bodyEnv = new Env(Some(localEnv))
|
||||
val newBody = {
|
||||
given localEnv: Env = bodyEnv
|
||||
body.map(renameStmt)
|
||||
}
|
||||
While(pos, newCond, newBody)
|
||||
}
|
||||
|
||||
case Block(pos, stmts) => {
|
||||
val blockEnv = new Env(Some(localEnv))
|
||||
val newStmts = {
|
||||
given localEnv: Env = blockEnv
|
||||
stmts.map(renameStmt)
|
||||
}
|
||||
Block(pos, newStmts)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Renames all expressions which may contain identities and ignores literals
|
||||
def renameExpr(expr: Expr)
|
||||
(using localEnv: Env, inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
|
||||
gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): Expr = {
|
||||
expr match {
|
||||
case Not(x) => Not(renameExpr(x))
|
||||
case Neg(x) => Neg(renameExpr(x))
|
||||
case Len(x) => Len(renameExpr(x))
|
||||
case Ord(x) => Ord(renameExpr(x))
|
||||
case Chr(x) => Chr(renameExpr(x))
|
||||
|
||||
case Mul(x, y) => Mul(renameExpr(x), renameExpr(y))
|
||||
case Div(x, y) => Div(renameExpr(x), renameExpr(y))
|
||||
case Mod(x, y) => Mod(renameExpr(x), renameExpr(y))
|
||||
|
||||
case Add(x, y) => Add(renameExpr(x), renameExpr(y))
|
||||
case Sub(x, y) => Sub(renameExpr(x), renameExpr(y))
|
||||
|
||||
case Gt(x, y) => Gt(renameExpr(x), renameExpr(y))
|
||||
case Geq(x, y) => Geq(renameExpr(x), renameExpr(y))
|
||||
case Lt(x, y) => Lt(renameExpr(x), renameExpr(y))
|
||||
case Leq(x, y) => Leq(renameExpr(x), renameExpr(y))
|
||||
|
||||
case Eq(x, y) => Eq(renameExpr(x), renameExpr(y))
|
||||
case Neq(x, y) => Neq(renameExpr(x), renameExpr(y))
|
||||
|
||||
case And(x, y) => And(renameExpr(x), renameExpr(y))
|
||||
case Or(x, y) => Or(renameExpr(x), renameExpr(y))
|
||||
|
||||
case ident: Ident => localEnv.lookup(ident.id) match {
|
||||
case Some(varSym) => Ident(ident.pos, varSym.uniqueName)
|
||||
case None => {
|
||||
// Undeclared Variable
|
||||
errors += UndefinedValError(ident.pos, ident.id)
|
||||
ident
|
||||
}
|
||||
}
|
||||
|
||||
case ArrayElem(pos, id, indices) => localEnv.lookup(id.id) match {
|
||||
case Some(varSym) =>
|
||||
ArrayElem(pos, Ident(id.pos, varSym.uniqueName), indices.map(renameExpr))
|
||||
case None => {
|
||||
// Undeclared Array
|
||||
errors += UndefinedValError(id.pos, id.id)
|
||||
ArrayElem(pos, id, indices.map(renameExpr))
|
||||
}
|
||||
}
|
||||
|
||||
case IfExpr(pos, cond, thenBranch, elseBranch) => {
|
||||
val newCond = renameExpr(cond)
|
||||
val thenEnv = new Env(Some(localEnv))
|
||||
val newThen = {
|
||||
given env: Env = thenEnv
|
||||
renameExpr(thenBranch)
|
||||
}
|
||||
val elseEnv = new Env(Some(localEnv))
|
||||
val newElse = {
|
||||
given env: Env = elseEnv
|
||||
renameExpr(elseBranch)
|
||||
}
|
||||
IfExpr(pos, newCond, newThen, newElse)
|
||||
}
|
||||
|
||||
// The rest are literals so no identities to rename here
|
||||
case rest => rest
|
||||
}
|
||||
}
|
||||
|
||||
def renameRhs(rhs: Rhs)
|
||||
(using localEnv: Env, inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
|
||||
gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): Rhs = {
|
||||
rhs match {
|
||||
case ArrayLiter(pos, elems) => ArrayLiter(pos, elems.map(renameExpr))
|
||||
case NewPair(pos, fst, snd) => NewPair(pos, renameExpr(fst), renameExpr(snd))
|
||||
case Call(pos, id, args) => {
|
||||
val newArgs = args.map(renameExpr)
|
||||
val funcSym = gFuncEnv.callLookup(id.id, newArgs, pos, localEnv).getOrElse {
|
||||
FuncSymbol(id.id, generateUnique(id.id), AnyType, Nil, id.pos)
|
||||
}
|
||||
val newId = Ident(id.pos, funcSym.uniqueName)
|
||||
Call(pos, newId, newArgs)
|
||||
}
|
||||
|
||||
case PairElem(pos, selector, lhs) => PairElem(pos, selector, renameLhs(lhs))
|
||||
|
||||
case expr: Expr => renameExpr(expr)
|
||||
}
|
||||
}
|
||||
|
||||
def renameLhs(lhs: Lhs)
|
||||
(using localEnv: Env, inFunction: Boolean, errors: mutable.ListBuffer[SemanticError],
|
||||
gFuncEnv: GlobalFuncEnv, gMainEnv: GlobalMainEnv): Lhs = {
|
||||
lhs match {
|
||||
case ident: Ident => localEnv.lookup(ident.id) match {
|
||||
case Some(varSym) => Ident(ident.pos, varSym.uniqueName)
|
||||
case None => {
|
||||
// Undeclared Variable
|
||||
errors += UndefinedValError(ident.pos, ident.id)
|
||||
ident
|
||||
}
|
||||
}
|
||||
|
||||
case ArrayElem(pos, id, indices) => localEnv.lookup(id.id) match {
|
||||
case Some(varSym) =>
|
||||
ArrayElem(pos, Ident(id.pos, varSym.uniqueName), indices.map(renameExpr))
|
||||
case None => {
|
||||
// Undeclared Array
|
||||
errors += UndefinedValError(id.pos, id.id)
|
||||
ArrayElem(pos, id, indices.map(renameExpr))
|
||||
}
|
||||
}
|
||||
|
||||
case PairElem(pos, selector, lhs) => PairElem(pos, selector, renameLhs(lhs))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
18
src/main/wacc/frontend/semantic/semanticAnalyser.scala
Normal file
18
src/main/wacc/frontend/semantic/semanticAnalyser.scala
Normal file
@ -0,0 +1,18 @@
|
||||
package wacc.frontend.semantic
|
||||
|
||||
import wacc.frontend.semantic.environment.GlobalFuncEnv
|
||||
import wacc.frontend.semantic.environment.GlobalMainEnv
|
||||
import wacc.frontend.syntax.ast.WProgram
|
||||
|
||||
def analyse(program: WProgram):
|
||||
Either[(WProgram, GlobalFuncEnv, GlobalMainEnv), List[SemanticError]] = {
|
||||
|
||||
val renamer = new Renamer()
|
||||
val (renamedAST, errors, globalFuncEnv, globalMainEnv) = renamer.renameProgram(program)
|
||||
val (renamedProgram, errs) =
|
||||
typeChecking.validateAST(renamedAST, globalFuncEnv.funcs.toMap)(using globalMainEnv, errors)
|
||||
errs match {
|
||||
case Nil => Left((renamedProgram, globalFuncEnv, globalMainEnv))
|
||||
case _ => Right(errs)
|
||||
}
|
||||
}
|
||||
156
src/main/wacc/frontend/semantic/semanticErrors.scala
Normal file
156
src/main/wacc/frontend/semantic/semanticErrors.scala
Normal file
@ -0,0 +1,156 @@
|
||||
package wacc.frontend.semantic
|
||||
|
||||
import scala.io.Source
|
||||
import wacc.frontend.semantic.environment._
|
||||
import wacc.frontend.syntax.ast.Type
|
||||
|
||||
val SEMANTIC_ERROR_CODE = 200
|
||||
|
||||
sealed trait SemanticError {
|
||||
def getPosition: (Int, Int)
|
||||
def getMessage(fileName: String, path: String): String
|
||||
}
|
||||
|
||||
// ====================================================
|
||||
// ERRORS FOUND DURING SCOPE ANALYSIS
|
||||
// ====================================================
|
||||
|
||||
case class UndefinedFuncError(pos: (Int, Int), id: String) extends SemanticError {
|
||||
override def getPosition: (Int, Int) = pos
|
||||
|
||||
override def getMessage(fileName: String, path: String): String = {
|
||||
val line = getLineFromFile(path, pos._1).getOrElse("<Unknown Line>").strip()
|
||||
s"Undefined function error in $fileName" +
|
||||
s" (${pos._1},${pos._2}):\n function $id is undefined:" +
|
||||
s"\n$line\n"
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: May add on the number of matches
|
||||
case class AmbiguousFuncCallError(pos: (Int, Int), id: String) extends SemanticError {
|
||||
override def getPosition: (Int, Int) = pos
|
||||
|
||||
override def getMessage(fileName: String, path: String): String = {
|
||||
val line = getLineFromFile(path, pos._1).getOrElse("<Unknown Line>").strip()
|
||||
s"Ambiguous function call in $fileName" +
|
||||
s" (${pos._1},${pos._2}):\n function call for $id is ambiguous:" +
|
||||
s"\n$line\n"
|
||||
}
|
||||
}
|
||||
|
||||
case class UndefinedValError(pos: (Int, Int), id: String)(using env: Env)
|
||||
extends SemanticError {
|
||||
override def getPosition: (Int, Int) = pos
|
||||
|
||||
override def getMessage(fileName: String, path: String): String = {
|
||||
val line = getLineFromFile(path, pos._1).getOrElse("<Unknown Line>").strip()
|
||||
val unpackedEnv = unpackEnv(env)
|
||||
val relevantInScope = s" relevant in-scope variables include:\n${unpackEnv(env)}\n | $line\n"
|
||||
s"Scope error in $fileName (${pos._1},${pos._2}):\n" +
|
||||
s" variable $id has not been declared in this scope:\n" +
|
||||
s"${if (unpackedEnv.size > 3) then relevantInScope else " there are no relevant in-scope variables\n"}"
|
||||
}
|
||||
|
||||
// Return a list of `\n` separated declarations, each line `_type _ident`.
|
||||
def unpackEnv(e: Env): String = gatherVars(e)
|
||||
.map(sym => s" ${sym._type.getType} ${sym.id} (declared on line ${sym.pos._1})")
|
||||
.mkString("\n")
|
||||
|
||||
// Recursively collect all visible symbols from `env` up to outer parents.
|
||||
def gatherVars(e: Env): List[VarSymbol] = {
|
||||
e.vars.values.toList ++
|
||||
(e.parent match {
|
||||
case Some(p) => gatherVars(p)
|
||||
case None => Nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
case class RedeclarationError(pos: (Int, Int), id: String, inFunction: Boolean, paramTypes: List[Type] = Nil)
|
||||
(using env: EnvTrait) extends SemanticError {
|
||||
|
||||
override def getPosition: (Int, Int) = pos
|
||||
override def getMessage(fileName: String, path: String): String = {
|
||||
val line = getLineFromFile(path, pos._1).getOrElse("<Unknown Line>").strip()
|
||||
if (inFunction) {
|
||||
s"Function redefinition error in $fileName (${pos._1},${pos._2}):\n" +
|
||||
s" illegal redefinition of function $id with parameters (${paramTypes.map(_.getType).mkString(", ")})\n" +
|
||||
s" previously declared on " +
|
||||
s"line ${env.lookup(id).get.pos._1}\n | $line\n"
|
||||
} else {
|
||||
s"Scope error in $fileName (${pos._1},${pos._2}):\n" +
|
||||
s" illegal redeclaration of variable $id\n" +
|
||||
s" previously declared (in this scope) on " +
|
||||
s"line ${env.lookup(id).get.pos._1}\n | $line\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ====================================================
|
||||
// ERRORS FOUND DURING TYPE CHECKING
|
||||
// ====================================================
|
||||
|
||||
case class InvalidTypeError(pos: (Int, Int), details: (String, String)) extends SemanticError {
|
||||
override def getPosition: (Int, Int) = pos
|
||||
override def getMessage(fileName: String, path: String): String = {
|
||||
val line = getLineFromFile(path, pos._1).getOrElse("<Unknown Line>").strip()
|
||||
s"Type error in $fileName (${pos._1},${pos._2}):\n" +
|
||||
s" unexpected ${details._1}\n" +
|
||||
s" expected ${details._2}\n" +
|
||||
s" | $line\n"
|
||||
}
|
||||
}
|
||||
|
||||
case class PairAssignmentError(pos: (Int, Int), details: (String, String)) extends SemanticError {
|
||||
override def getPosition: (Int, Int) = pos
|
||||
override def getMessage(fileName: String, path: String): String = {
|
||||
val line = getLineFromFile(path, pos._1).getOrElse("<Unknown Line>").strip()
|
||||
s"Type error in $fileName (${pos._1},${pos._2}):\n" +
|
||||
s" ${details._2}\n" +
|
||||
s" | $line\n"
|
||||
}
|
||||
}
|
||||
|
||||
case class InvalidConditionError(pos: (Int, Int), details: (String, String)) extends SemanticError {
|
||||
override def getPosition: (Int, Int) = pos
|
||||
override def getMessage(fileName: String, path: String): String = {
|
||||
val line = getLineFromFile(path, pos._1).getOrElse("<Unknown Line>").strip()
|
||||
s"Condition error in $fileName (${pos._1},${pos._2}):\n" +
|
||||
s" unexpected ${details._1}\n" +
|
||||
s" ${details._2}\n" +
|
||||
s" | $line\n"
|
||||
}
|
||||
}
|
||||
|
||||
case class FunctionCallError(pos: (Int, Int), details: (String, String)) extends SemanticError {
|
||||
override def getPosition: (Int, Int) = pos
|
||||
override def getMessage(fileName: String, path: String): String = {
|
||||
val line = getLineFromFile(path, pos._1).getOrElse("<Unknown Line>").strip()
|
||||
s"Function call error in $fileName (${pos._1},${pos._2}):\n" +
|
||||
s" ${details._2}\n" +
|
||||
s" | $line\n"
|
||||
}
|
||||
}
|
||||
|
||||
case class GeneralTypeError(pos: (Int, Int), details: (String, String)) extends SemanticError {
|
||||
override def getPosition: (Int, Int) = pos
|
||||
override def getMessage(fileName: String, path: String): String = {
|
||||
val line = getLineFromFile(path, pos._1).getOrElse("<Unknown Line>").strip()
|
||||
s"Type error in $fileName (${pos._1},${pos._2}):\n" +
|
||||
s" unexpected ${details._1}\n" +
|
||||
s" ${details._2}\n" +
|
||||
s" | $line\n"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private def getLineFromFile(path: String, lineNumber: Int): Option[String] = {
|
||||
try {
|
||||
val source = Source.fromFile(path)
|
||||
val line = source.getLines().drop(lineNumber - 1).nextOption() // Get the specific line
|
||||
source.close()
|
||||
line
|
||||
} catch {
|
||||
case _: Exception => None // Handle errors if file can't be read
|
||||
}
|
||||
}
|
||||
20
src/main/wacc/frontend/semantic/symbols.scala
Normal file
20
src/main/wacc/frontend/semantic/symbols.scala
Normal file
@ -0,0 +1,20 @@
|
||||
package wacc.frontend.semantic
|
||||
|
||||
import wacc.frontend.syntax.ast._
|
||||
|
||||
// Base trait for a resolved symbol.
|
||||
sealed trait Symbol {
|
||||
val id: String
|
||||
val _type: Type
|
||||
val pos: (Int, Int)
|
||||
}
|
||||
|
||||
// A symbol representing a variable (or function parameter).
|
||||
// The field `uniqueName` holds the new (alpha–converted) name.
|
||||
case class VarSymbol(id: String, uniqueName: String, _type: Type, pos: (Int, Int))
|
||||
extends Symbol
|
||||
|
||||
// A symbol representing a function.
|
||||
// The field `uniqueName` holds the new (alpha–converted) name.
|
||||
case class FuncSymbol(id: String, uniqueName: String, _type: Type,
|
||||
params: List[VarSymbol], pos: (Int, Int)) extends Symbol
|
||||
293
src/main/wacc/frontend/semantic/typeChecking.scala
Normal file
293
src/main/wacc/frontend/semantic/typeChecking.scala
Normal file
@ -0,0 +1,293 @@
|
||||
package wacc.frontend.semantic
|
||||
|
||||
import wacc.frontend.semantic.environment.GlobalMainEnv
|
||||
import wacc.frontend.syntax.ast._
|
||||
import scala.collection.mutable
|
||||
|
||||
object typeChecking {
|
||||
|
||||
// Chooses the correct error type
|
||||
private def addSemanticError(pos: (Int, Int), offending: String, message: String,
|
||||
errConstructor: ( (Int, Int), (String, String) ) => SemanticError)
|
||||
(using errors: mutable.ListBuffer[SemanticError]): Unit =
|
||||
errors += errConstructor(pos, (offending, message))
|
||||
|
||||
def validateAST(program: WProgram, fMap: Map[String, FuncSymbol])
|
||||
(using globalMainEnv: GlobalMainEnv, errors: mutable.ListBuffer[SemanticError]):
|
||||
(WProgram, List[SemanticError]) = {
|
||||
|
||||
given funcMap: Map[String, FuncSymbol] = fMap
|
||||
|
||||
val WProgram(pos, imports, funcs, stmts) = program
|
||||
funcs.foreach(validateFunc)
|
||||
stmts.foreach(validateStmt)
|
||||
|
||||
(program, errors.sortBy(semErr => semErr.getPosition).toList)
|
||||
}
|
||||
|
||||
def validateFunc(func: Func)
|
||||
(using funcMap: Map[String, FuncSymbol], globalMainEnv: GlobalMainEnv,
|
||||
errors: mutable.ListBuffer[SemanticError]): Unit = {
|
||||
|
||||
val Func(pos, retType, ident, params, stmts) = func
|
||||
stmts.foreach(validateStmt(_)(using funcMap, Some(retType), params))
|
||||
}
|
||||
|
||||
def validateStmt(stmt: Stmt)
|
||||
(using funcMap: Map[String, FuncSymbol], funcReturnType: Option[Type] = None,
|
||||
paramList: List[Param] = Nil, globalMainEnv: GlobalMainEnv,
|
||||
errors: mutable.ListBuffer[SemanticError]): Unit = stmt match {
|
||||
case Skip(_) => ()
|
||||
case Print(pos, expr) => findExprType(pos, expr)
|
||||
case Println(pos, expr) => findExprType(pos, expr)
|
||||
|
||||
case Declare(pos, declType, ident, rhs) =>
|
||||
findRhsType(pos, rhs).foreach { rhsType =>
|
||||
if rhsType != PairPlaceholder && !(rhsType typecast Some(declType)) then
|
||||
addSemanticError(pos, s"${declType.getType}", s"${rhsType.getType}", InvalidTypeError.apply)
|
||||
}
|
||||
|
||||
case Assign(pos, lhs, rhs) => (findLhsType(pos, lhs), findRhsType(pos, rhs)) match {
|
||||
case (Some(PairPlaceholder), Some(PairPlaceholder)) =>
|
||||
addSemanticError(pos, s"", "Cannot assign unknown types (both PairPlaceholder)",
|
||||
PairAssignmentError.apply)
|
||||
|
||||
case (Some(lhsType), Some(rhsType)) =>
|
||||
// highlight: unify the “check if both AnyType” into a single guard
|
||||
if lhsType == AnyType && rhsType == AnyType then
|
||||
addSemanticError(pos, s"${lhsType.getType}", "Cannot assign AnyType to AnyType", GeneralTypeError.apply)
|
||||
|
||||
// highlight: more concise checks for PairPlaceholder and PairType
|
||||
val lhsIsPartialPair =
|
||||
lhsType == PairPlaceholder && rhsType.isInstanceOf[PairType]
|
||||
val rhsIsPartialPair =
|
||||
rhsType == PairPlaceholder && lhsType.isInstanceOf[PairType]
|
||||
|
||||
if !lhsIsPartialPair && !rhsIsPartialPair &&
|
||||
!(rhsType typecast Some(lhsType))
|
||||
then
|
||||
addSemanticError(pos, s"${lhsType.getType}", s"${rhsType.getType}", InvalidTypeError.apply)
|
||||
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
case Read(pos, lhs) => findLhsType(pos, lhs) match {
|
||||
case Some(IntType) | Some(CharType) =>
|
||||
case lhsTypeOpt => addSemanticError(pos, s"${lhsTypeOpt.getOrElse(AnyType).getType}",
|
||||
"Invalid type for Read (must be int or char)", GeneralTypeError.apply)
|
||||
}
|
||||
|
||||
case Free(pos, expr) => findExprType(pos, expr) match {
|
||||
case Some(ArrayType(_)) | Some(PairType(_, _)) => ()
|
||||
case typeOpt => addSemanticError(pos, s"${typeOpt.getOrElse(AnyType).getType}",
|
||||
"Invalid type cannot be freed (must be a pair or array)", GeneralTypeError.apply)
|
||||
}
|
||||
|
||||
case Return(pos, expr) =>
|
||||
if funcReturnType.isEmpty then
|
||||
addSemanticError(pos, s"", "Return statement outside of function", PairAssignmentError.apply)
|
||||
|
||||
findExprType(pos, expr).foreach { exprType =>
|
||||
if !(exprType typecast funcReturnType) then
|
||||
addSemanticError(pos,
|
||||
s"unexpected ${exprType.getType}\n expected ${funcReturnType.getOrElse(AnyType).getType}",
|
||||
"Invalid return type", PairAssignmentError.apply)
|
||||
}
|
||||
|
||||
case Exit(pos, expr) => findExprType(pos, expr) match {
|
||||
case Some(IntType) => ()
|
||||
case typeOpt => addSemanticError(pos, s"${typeOpt.getOrElse(AnyType).getType}",
|
||||
"Exit expression must be an int", GeneralTypeError.apply)
|
||||
}
|
||||
|
||||
case Block(_, blockStmts) => blockStmts.foreach(validateStmt)
|
||||
|
||||
case If(pos, cond, thenCase, elseCase) => findExprType(pos, cond) match {
|
||||
case Some(BoolType) =>
|
||||
thenCase.foreach(validateStmt)
|
||||
elseCase.foreach(validateStmt)
|
||||
case typeOpt => addSemanticError(pos, s"${typeOpt.getOrElse(AnyType).getType}",
|
||||
"Invalid condition type in If (must be bool)", InvalidConditionError.apply)
|
||||
}
|
||||
|
||||
case While(pos, cond, body) => findExprType(pos, cond) match {
|
||||
case Some(BoolType) => body.foreach(validateStmt)
|
||||
case typeOpt => addSemanticError(pos, s"${typeOpt.getOrElse(AnyType).getType}",
|
||||
"Invalid condition type in While (must be bool)", InvalidConditionError.apply)
|
||||
}
|
||||
}
|
||||
|
||||
def findExprType(pos: (Int, Int), expr: Expr)
|
||||
(using params: List[Param] = Nil, globalMainEnv: GlobalMainEnv,
|
||||
errors: mutable.ListBuffer[SemanticError]): Option[Type] = expr match {
|
||||
case Ident(_, id) =>
|
||||
globalMainEnv.lookupType(id).orElse(params.find(_.id.id == id).map(_.t))
|
||||
|
||||
case IntLiter(_, _) => Some(IntType)
|
||||
case StrLiter(_, _) => Some(StrType)
|
||||
case BoolLiter(_, _) => Some(BoolType)
|
||||
case CharLiter(_, _) => Some(CharType)
|
||||
case PairLiter(_) => Some(PairPlaceholder)
|
||||
|
||||
case Not(x) => if unaryOpTypeCheck(pos, x, List(BoolType)) then Some(BoolType) else None
|
||||
case Neg(x) => if unaryOpTypeCheck(pos, x, List(IntType)) then Some(IntType) else None
|
||||
case Len(x) => if unaryOpTypeCheck(pos, x, List(ArrayType(AnyType))) then Some(IntType) else None
|
||||
case Ord(x) => if unaryOpTypeCheck(pos, x, List(CharType)) then Some(IntType) else None
|
||||
case Chr(x) => if unaryOpTypeCheck(pos, x, List(IntType)) then Some(CharType) else None
|
||||
|
||||
case Div(x, y) => if binOpTypeCheck(pos, x, y, List(IntType)) then Some(IntType) else None
|
||||
case Mul(x, y) => if binOpTypeCheck(pos, x, y, List(IntType)) then Some(IntType) else None
|
||||
case Mod(x, y) => if binOpTypeCheck(pos, x, y, List(IntType)) then Some(IntType) else None
|
||||
case Add(x, y) => if binOpTypeCheck(pos, x, y, List(IntType)) then Some(IntType) else None
|
||||
case Sub(x, y) => if binOpTypeCheck(pos, x, y, List(IntType)) then Some(IntType) else None
|
||||
|
||||
case Geq(x, y) => if binOpTypeCheck(pos, x, y, List(IntType, CharType)) then Some(BoolType) else None
|
||||
case Gt (x, y) => if binOpTypeCheck(pos, x, y, List(IntType, CharType)) then Some(BoolType) else None
|
||||
case Lt (x, y) => if binOpTypeCheck(pos, x, y, List(IntType, CharType)) then Some(BoolType) else None
|
||||
case Leq(x, y) => if binOpTypeCheck(pos, x, y, List(IntType, CharType)) then Some(BoolType) else None
|
||||
|
||||
case Eq (x, y) => if binOpTypeCheck(pos, x, y, List(AnyType)) then Some(BoolType) else None
|
||||
case Neq(x, y) => if binOpTypeCheck(pos, x, y, List(AnyType)) then Some(BoolType) else None
|
||||
|
||||
case And(x, y) => if binOpTypeCheck(pos, x, y, List(BoolType)) then Some(BoolType) else None
|
||||
case Or (x, y) => if binOpTypeCheck(pos, x, y, List(BoolType)) then Some(BoolType) else None
|
||||
|
||||
case PairElem(_, _, lhs) => findLhsType(pos, lhs)
|
||||
case ArrayElem(_, ident, exprs) => arrayAccessType(pos, ident, exprs)
|
||||
|
||||
case IfExpr(pos, cond, thenBranch, elseBranch) =>
|
||||
val thenType = findExprType(pos, thenBranch)
|
||||
val elseType = findExprType(pos, elseBranch)
|
||||
if (thenType.exists(_ typecast elseType)) then thenType
|
||||
else if (elseType.exists(_ typecast thenType)) then elseType
|
||||
else {
|
||||
addSemanticError(pos, s"${thenType.getOrElse(AnyType).getType}, ${elseType.getOrElse(AnyType).getType}",
|
||||
"Incompatible types in If-Expr branches", GeneralTypeError.apply)
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
def unaryOpTypeCheck(pos: (Int, Int), x: Expr, requiredTypes: List[Type])
|
||||
(using params: List[Param], globalMainEnv: GlobalMainEnv,
|
||||
errors: mutable.ListBuffer[SemanticError]): Boolean = findExprType(pos, x) match {
|
||||
case Some(lhsType) if requiredTypes.exists(_ typecast Some(lhsType)) => true
|
||||
case typeOpt =>
|
||||
addSemanticError(pos, s"${typeOpt.getOrElse(AnyType).getType} ", "Invalid type in unary operator", InvalidTypeError.apply)
|
||||
false
|
||||
}
|
||||
|
||||
def binOpTypeCheck(pos: (Int, Int), x: Expr, y: Expr, requiredTypes: List[Type])
|
||||
(using params: List[Param], globalMainEnv: GlobalMainEnv,
|
||||
errors: mutable.ListBuffer[SemanticError]): Boolean = {
|
||||
val lhsOpt = findExprType(pos, x)
|
||||
val rhsOpt = findExprType(pos, y)
|
||||
(lhsOpt, rhsOpt) match {
|
||||
case (Some(lhsType), Some(rhsType))
|
||||
if (requiredTypes.exists(_ typecast lhsOpt) || requiredTypes == List(AnyType))
|
||||
&& (lhsType typecast rhsOpt) => true
|
||||
case _ =>
|
||||
addSemanticError(pos, s"got ${lhsOpt.getOrElse(AnyType).getType}, ${rhsOpt.getOrElse(AnyType).getType}",
|
||||
s"Invalid types in binary operator", GeneralTypeError.apply)
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
def findRhsType(pos: (Int, Int), rhs: Rhs)
|
||||
(using funcMap: Map[String, FuncSymbol], params: List[Param] = Nil,
|
||||
globalMainEnv: GlobalMainEnv, errors: mutable.ListBuffer[SemanticError]):
|
||||
Option[Type] = rhs match {
|
||||
|
||||
case ArrayLiter(_, elems) =>
|
||||
val elemTypes = elems.flatMap(e => findExprType(pos, e))
|
||||
elemTypes.distinct match {
|
||||
case t :: Nil => Some(ArrayType(t))
|
||||
case Nil => Some(ArrayType(AnyType)) // empty array => AnyType
|
||||
case multiple =>
|
||||
val lcaType = multiple.reduceLeft { (t1, t2) =>
|
||||
if t1.typecast(Some(t2)) then t2
|
||||
else if t2.typecast(Some(t1)) then t1
|
||||
else
|
||||
addSemanticError(pos, s"[${elemTypes.map(elem => elem.getType).mkString(", ")}]", "Multiple incompatible types in array",
|
||||
GeneralTypeError.apply)
|
||||
AnyType
|
||||
}
|
||||
Some(ArrayType(lcaType))
|
||||
}
|
||||
|
||||
case NewPair(_, fst, snd) => (findExprType(pos, fst), findExprType(pos, snd)) match {
|
||||
case (Some(f), Some(s)) => Some(PairType(f, s))
|
||||
case _ =>
|
||||
addSemanticError(pos, s"", "One of those types is invalid for newpair",
|
||||
PairAssignmentError.apply)
|
||||
None
|
||||
}
|
||||
|
||||
case PairElem(_, selector, _) => findLhsType(pos, rhs.asInstanceOf[Lhs])
|
||||
|
||||
case Call(callPos, ident, args) =>
|
||||
val funcNonUniqueName = ident.id.replaceAll("#\\d+$", "")
|
||||
funcMap.get(ident.id) match {
|
||||
case Some(FuncSymbol(_, _, returnType, paramTypes, _)) =>
|
||||
if args.size != paramTypes.size then
|
||||
addSemanticError(callPos, s"", s"Function ${funcNonUniqueName} called with wrong # of arguments",
|
||||
FunctionCallError.apply)
|
||||
val argTypes = args.flatMap(findExprType(pos, _))
|
||||
argTypes.zip(paramTypes).foreach { case (aType, pType) =>
|
||||
if !(aType typecast Some(pType._type)) then
|
||||
addSemanticError(callPos, s"",
|
||||
s"Argument type ${aType.getType} does not match param type ${pType._type.getType}",
|
||||
FunctionCallError.apply)
|
||||
}
|
||||
Some(returnType)
|
||||
case None =>
|
||||
addSemanticError(callPos, s"", s"Function ${funcNonUniqueName} not found",
|
||||
FunctionCallError.apply)
|
||||
None
|
||||
}
|
||||
|
||||
case expr: Expr => findExprType(pos, expr)
|
||||
}
|
||||
|
||||
def findLhsType(pos: (Int, Int), lhs: Lhs)
|
||||
(using params: List[Param] = Nil, globalMainEnv: GlobalMainEnv,
|
||||
errors: mutable.ListBuffer[SemanticError]): Option[Type] = lhs match {
|
||||
case ident: Ident => findExprType(pos, ident)
|
||||
|
||||
case ArrayElem(_, arrayIdent, indices) => arrayAccessType(pos, arrayIdent, indices)
|
||||
|
||||
case PairElem(_, selector, subLhs) =>
|
||||
val subType = findLhsType(pos, subLhs)
|
||||
// If it's a partial pair with PairPlaceholder in one slot, allow for PairPlaceholder
|
||||
(selector, subType) match {
|
||||
case (Fst, Some(PairType(PairPlaceholder, _))) => Some(PairPlaceholder)
|
||||
case (Snd, Some(PairType(_, PairPlaceholder))) => Some(PairPlaceholder)
|
||||
case (Fst, Some(PairType(t:Type, _))) => Some(t)
|
||||
case (Snd, Some(PairType(_, t:Type))) => Some(t)
|
||||
case _ => Some(AnyType)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle array–element type extraction, checking index types.
|
||||
def arrayAccessType(pos: (Int, Int), ident: Ident, indices: List[Expr])
|
||||
(using params: List[Param] = Nil, globalMainEnv: GlobalMainEnv,
|
||||
errors: mutable.ListBuffer[SemanticError]): Option[Type] = findExprType(pos, ident) match {
|
||||
case Some(arrType) =>
|
||||
// Must all be Int
|
||||
val validIndices = indices.forall(findExprType(pos, _) contains IntType)
|
||||
if !validIndices then addSemanticError(pos, s"${arrType.getType}", "Array indices must be of type int",
|
||||
GeneralTypeError.apply)
|
||||
|
||||
// Peel off array layers
|
||||
val finalType = indices.foldLeft(arrType) {
|
||||
case (ArrayType(inner), _) => inner
|
||||
case (nonArr, _) =>
|
||||
addSemanticError(pos, s"$ident", "Too many indices for array", PairAssignmentError.apply)
|
||||
nonArr
|
||||
}
|
||||
Some(finalType)
|
||||
case None | Some(_) =>
|
||||
addSemanticError(pos, s"", "Invalid array access (identifier not found or not an array)",
|
||||
PairAssignmentError.apply)
|
||||
None
|
||||
}
|
||||
}
|
||||
466
src/main/wacc/frontend/syntax/ast.scala
Normal file
466
src/main/wacc/frontend/syntax/ast.scala
Normal file
@ -0,0 +1,466 @@
|
||||
package wacc.frontend.syntax
|
||||
|
||||
import parsley.generic.{ParserBridge1, ParserBridge2, ParserBridge3, ParserBridge4/*, ParserBridge5*/}
|
||||
|
||||
object ast {
|
||||
// ====================================================
|
||||
// PROGRAM STRUCTURE
|
||||
// ====================================================
|
||||
|
||||
// ⟨program⟩ ::= ‘begin’ ⟨func⟩* ⟨stmt⟩ ‘end’. Condenses <stmt> into List[Stmt]
|
||||
case class WProgram(pos: (Int, Int), imports: List[Import], funcs: List[Func], stats: List[Stmt])
|
||||
|
||||
// Import statement: ‘import’ ⟨string⟩
|
||||
case class Import(pos: (Int, Int), path: String)
|
||||
|
||||
// A function definition:
|
||||
// ⟨func⟩ ::= ⟨type⟩ ⟨ident⟩ ‘(’ ⟨param-list⟩? ‘)’ ‘is’ ⟨stmt⟩ ‘end’
|
||||
// ⟨param-list⟩ ::= ⟨param⟩ ( ‘,’ ⟨param⟩ )*
|
||||
case class Func(pos: (Int, Int), t: Type, id: Ident, args: List[Param], body: List[Stmt])
|
||||
|
||||
// A function parameter: ⟨param⟩ ::= ⟨type⟩ ⟨ident⟩
|
||||
case class Param(pos: (Int, Int), t: Type, id: Ident)
|
||||
|
||||
// ====================================================
|
||||
// STATEMENTS
|
||||
// ====================================================
|
||||
sealed trait Stmt
|
||||
|
||||
// A “skip” statement: ‘skip’
|
||||
case class Skip(pos: (Int, Int)) extends Stmt
|
||||
|
||||
// Variable declaration: ⟨type⟩ ⟨ident⟩ ‘=’ ⟨rvalue⟩
|
||||
case class Declare(pos: (Int, Int), t: Type, ident: Ident, rhs: Rhs) extends Stmt
|
||||
|
||||
// Assignment: ⟨lvalue⟩ ‘=’ ⟨rvalue⟩
|
||||
case class Assign(pos: (Int, Int), lhs: Lhs, rhs: Rhs) extends Stmt
|
||||
|
||||
// Read statement: ‘read’ ⟨lvalue⟩
|
||||
case class Read(pos: (Int, Int), lhs: Lhs) extends Stmt
|
||||
|
||||
// Free statement: ‘free’ ⟨expr⟩
|
||||
case class Free(pos: (Int, Int), expr: Expr) extends Stmt
|
||||
|
||||
// Return statement: ‘return’ ⟨expr⟩
|
||||
case class Return(pos: (Int, Int), expr: Expr) extends Stmt
|
||||
|
||||
// Exit statement: ‘exit’ ⟨expr⟩
|
||||
case class Exit(pos: (Int, Int), expr: Expr) extends Stmt
|
||||
|
||||
// Sealed trait for both prints (Used for code gen)
|
||||
sealed trait PrintBase extends Stmt {
|
||||
def pos: (Int, Int)
|
||||
def expr: Expr
|
||||
}
|
||||
|
||||
// Print statement: ‘print’ ⟨expr⟩
|
||||
case class Print(pos: (Int, Int), expr: Expr) extends PrintBase
|
||||
|
||||
// Println statement: ‘println’ ⟨expr⟩
|
||||
case class Println(pos: (Int, Int), expr: Expr) extends PrintBase
|
||||
|
||||
// If statement: ‘if’ ⟨expr⟩ ‘then’ ⟨stmt⟩ ‘else’ ⟨stmt⟩ ‘fi’
|
||||
case class If(pos: (Int, Int), expr: Expr,
|
||||
thenCase: List[Stmt], elseCase: List[Stmt]) extends Stmt
|
||||
|
||||
// While loop: ‘while’ ⟨expr⟩ ‘do’ ⟨stmt⟩ ‘done’
|
||||
case class While(pos: (Int, Int), expr: Expr, body: List[Stmt]) extends Stmt
|
||||
|
||||
// Block statement: ‘begin’ ⟨stmt⟩ ‘end’
|
||||
case class Block(pos: (Int, Int), stmts: List[Stmt]) extends Stmt
|
||||
|
||||
// ⟨stmt⟩ ‘;’ ⟨stmt⟩ is captured by List[Stmt]
|
||||
|
||||
// ====================================================
|
||||
// LVALUE AND RVALUE
|
||||
// ====================================================
|
||||
// Both can be expressions
|
||||
sealed trait Lhs
|
||||
sealed trait Rhs
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Expressions
|
||||
// ----------------------------------------------------
|
||||
|
||||
sealed trait Expr extends Rhs
|
||||
|
||||
// Literals
|
||||
case class IntLiter(pos: (Int, Int), value: BigInt) extends Expr
|
||||
case class BoolLiter(pos: (Int, Int), value: Boolean) extends Expr
|
||||
case class CharLiter(pos: (Int, Int), value: Char) extends Expr
|
||||
case class StrLiter(pos: (Int, Int), value: String) extends Expr
|
||||
|
||||
// The ‘null‘ literal (for pairs)
|
||||
case class PairLiter(pos: (Int, Int)) extends Expr
|
||||
|
||||
// Array literal: ⟨array-liter⟩ ::= ‘[’ ( ⟨expr⟩ (‘,’ ⟨expr⟩)* )? ‘]’
|
||||
case class ArrayLiter(pos: (Int, Int), elems: List[Expr]) extends Rhs
|
||||
|
||||
// ‘newpair’ ‘(’ ⟨expr⟩ ‘,’ ⟨expr⟩ ‘)’
|
||||
case class NewPair(pos: (Int, Int), fst: Expr, snd: Expr) extends Rhs
|
||||
|
||||
// Function call: ‘call’ ⟨ident⟩ ‘(’ ⟨arg-list⟩? ‘)’
|
||||
case class Call(pos: (Int, Int), id: Ident, args: List[Expr]) extends Rhs
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Operators
|
||||
// ----------------------------------------------------
|
||||
|
||||
// Descending Order of Precedence, with top paragraph being most tightly binding
|
||||
|
||||
// Unary operators
|
||||
|
||||
sealed trait UnOp extends Expr {
|
||||
def expr: Expr
|
||||
}
|
||||
|
||||
case class IfExpr(pos: (Int, Int), cond: Expr, thenBranch: Expr, elseBranch: Expr) extends Expr
|
||||
|
||||
// Prefix
|
||||
case class Not(expr: Expr) extends UnOp
|
||||
case class Neg(expr: Expr) extends UnOp
|
||||
case class Len(expr: Expr) extends UnOp
|
||||
case class Ord(expr: Expr) extends UnOp
|
||||
case class Chr(expr: Expr) extends UnOp
|
||||
|
||||
// Binary operators
|
||||
sealed trait BinOp extends Expr {
|
||||
def lhs: Expr
|
||||
def rhs: Expr
|
||||
}
|
||||
|
||||
sealed trait DivOp extends BinOp
|
||||
sealed trait CompOp extends BinOp
|
||||
|
||||
// infix left
|
||||
case class Mul(lhs: Expr, rhs: Expr) extends BinOp
|
||||
|
||||
case class Div(lhs: Expr, rhs: Expr) extends DivOp
|
||||
case class Mod(lhs: Expr, rhs: Expr) extends DivOp
|
||||
|
||||
// infix left
|
||||
case class Add(lhs: Expr, rhs: Expr) extends BinOp
|
||||
case class Sub(lhs: Expr, rhs: Expr) extends BinOp
|
||||
|
||||
// infix non
|
||||
case class Gt(lhs: Expr, rhs: Expr) extends CompOp
|
||||
case class Geq(lhs: Expr, rhs: Expr) extends CompOp
|
||||
case class Lt(lhs: Expr, rhs: Expr) extends CompOp
|
||||
case class Leq(lhs: Expr, rhs: Expr) extends CompOp
|
||||
|
||||
// infix non
|
||||
case class Eq(lhs: Expr, rhs: Expr) extends CompOp
|
||||
case class Neq(lhs: Expr, rhs: Expr) extends CompOp
|
||||
|
||||
// infix right
|
||||
case class And(lhs: Expr, rhs: Expr) extends BinOp
|
||||
case class Or(lhs: Expr, rhs: Expr) extends BinOp
|
||||
|
||||
// ----------------------------------------------------
|
||||
// LHS: Left-hand sides for assignments
|
||||
// ----------------------------------------------------
|
||||
// An identifier: ⟨ident⟩
|
||||
case class Ident(pos: (Int, Int), id: String) extends Lhs with Expr {
|
||||
override def equals(obj: Any): Boolean = obj match {
|
||||
case Ident(_, id2) => this.id == id2
|
||||
case _ => false
|
||||
}
|
||||
override def hashCode(): Int = id.hashCode
|
||||
}
|
||||
|
||||
// An array element access: ⟨array-elem⟩
|
||||
case class ArrayElem(pos: (Int, Int), id: Ident, indices: List[Expr]) extends Lhs with Expr
|
||||
|
||||
// Pair element: ⟨pair-elem⟩
|
||||
sealed trait PairSelector
|
||||
case object Fst extends PairSelector
|
||||
case object Snd extends PairSelector
|
||||
// A pair element is both an lvalue and an expression:
|
||||
// ⟨pair-elem⟩ ::= ‘fst’ ⟨lvalue⟩ | ‘snd’ ⟨lvalue⟩
|
||||
case class PairElem(pos: (Int, Int), selector: PairSelector, expr: Lhs) extends Lhs with Expr
|
||||
|
||||
// ====================================================
|
||||
// TYPES
|
||||
// ====================================================
|
||||
|
||||
// For pair types, the elements in the declarations must be either a base type or an array type
|
||||
// as (Nested pairs are “erased”).
|
||||
// A separate trait to mark types that can appear as elements of a pair.
|
||||
sealed trait PairElemType {
|
||||
|
||||
/**
|
||||
* A helper method to compare array subtypes. It returns true if x and y are exactly the same,
|
||||
* or if x and y are further pair-compatible.
|
||||
*/
|
||||
private def arrayCompatible(x: PairElemType, y: PairElemType): Boolean =
|
||||
(x == y) || ((x, y) match {
|
||||
case (a: PairElemType, b: PairElemType) => a.paircast(Some(b))
|
||||
})
|
||||
|
||||
/**
|
||||
* Attempt to “typecast” this type into the other (Option[PairElemType]).
|
||||
* For example, 'AnyType' can match anything, an Array of 'AnyType' can
|
||||
* match any array, etc. If the other is None, returns false.
|
||||
*/
|
||||
infix def typecast(other: Option[PairElemType]): Boolean = other match {
|
||||
case None => false
|
||||
|
||||
case Some(o) => (this, o) match {
|
||||
// Pair placeholders can match a concrete pair (and vice versa)
|
||||
case (PairType(_, _), PairPlaceholder) | (PairPlaceholder, PairType(_, _)) => true
|
||||
|
||||
// For two actual PairTypes, defer to paircast
|
||||
case (PairType(_, _), PairType(_, _)) => this.paircast(other)
|
||||
|
||||
// An array of AnyType is compatible with any array
|
||||
case (ArrayType(AnyType), ArrayType(_)) => true
|
||||
|
||||
// Compare arrays element–wise (including deeper pair casting)
|
||||
case (ArrayType(x), ArrayType(y)) => arrayCompatible(x, y)
|
||||
|
||||
// "char[]" can be cast to "string"
|
||||
case (ArrayType(CharType), StrType) => true
|
||||
|
||||
// "AnyType" matches anything, or anything can be cast to "AnyType"
|
||||
case (AnyType, _) => true
|
||||
case (_, AnyType) => true
|
||||
|
||||
// Otherwise, check for exact equality
|
||||
case _ => this == o
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A specialized version of casting that strictly compares pairs. If it is not
|
||||
* a pair scenario or an exactly matching type, this returns false.
|
||||
*/
|
||||
infix def paircast(other: Option[PairElemType]): Boolean = other match {
|
||||
case None => false
|
||||
case Some(o) => (this, o) match {
|
||||
case (PairType(f1, s1), PairType(f2, s2)) =>
|
||||
(f1.paircast(Some(f2))) && (s1.paircast(Some(s2)))
|
||||
|
||||
// Pair placeholders can match any concrete PairType
|
||||
case (PairType(_, _), PairPlaceholder) => true
|
||||
case (PairPlaceholder, PairType(_, _)) => true
|
||||
|
||||
// Array[AnyType] can match any array
|
||||
case (ArrayType(AnyType), ArrayType(_)) => true
|
||||
|
||||
// For two arrays, they must be exactly the same. (No 'AnyType' logic here.)
|
||||
case (ArrayType(x), ArrayType(y)) => x == y
|
||||
|
||||
// Otherwise, they must be exactly equal
|
||||
case _ => this == o
|
||||
}
|
||||
}
|
||||
|
||||
def getType: String
|
||||
}
|
||||
|
||||
sealed trait Type extends PairElemType
|
||||
|
||||
case object AnyType extends Type { override def getType: String = "any"}
|
||||
|
||||
// Base types
|
||||
case object IntType extends Type {
|
||||
override def getType = "int"
|
||||
}
|
||||
case object BoolType extends Type { override def getType = "bool" }
|
||||
case object CharType extends Type { override def getType = "char" }
|
||||
case object StrType extends Type { override def getType = "string" }
|
||||
|
||||
// Array type: any type followed by [].
|
||||
case class ArrayType(elemType: Type) extends Type {
|
||||
override def getType = s"${elemType.getType}[]"
|
||||
}
|
||||
|
||||
// I think this is needed for placeholder pair types?
|
||||
// e.g: pair(int, pair) is valid but pair(pair(int, char), bool) is not
|
||||
case object PairPlaceholder extends Type {
|
||||
override def getType = "pair"
|
||||
}
|
||||
|
||||
// A fully–specified pair type.
|
||||
case class PairType(fst: PairElemType, snd: PairElemType) extends Type {
|
||||
// Custom equality: AnyType acts as a wildcard
|
||||
override def equals(obj: Any): Boolean = obj match {
|
||||
case PairType(f1, s1) =>
|
||||
(fst == AnyType || fst == f1) && (snd == AnyType || snd == s1)
|
||||
case _ => false
|
||||
}
|
||||
|
||||
// Ensure that equal objects have the same hash code
|
||||
override def hashCode(): Int = {
|
||||
(if (fst == AnyType) 0 else fst.hashCode) ^
|
||||
(if (snd == AnyType) 0 else snd.hashCode)
|
||||
}
|
||||
|
||||
override def getType = s"pair(${fst.getType}, ${snd.getType})"
|
||||
}
|
||||
|
||||
// ====================================================
|
||||
// PARSER BRIDGES
|
||||
// ====================================================
|
||||
|
||||
// We can have null extending base type or as a literal extending expr im not sure yet
|
||||
// case object Null extends BaseType {override def getType() = "null"}
|
||||
|
||||
object WProgram extends ParserBridge4[(Int, Int), List[Import], List[Func], List[Stmt], WProgram]
|
||||
object Import extends ParserBridge2[(Int, Int), String, Import]
|
||||
object Func extends ParserBridge3[(((Int, Int), Type), Ident), List[Param], List[Stmt], Func] {
|
||||
def apply(tupleOfPosTypeIdent: (((Int, Int), Type), Ident), params: List[Param], body: List[Stmt]):
|
||||
Func = {
|
||||
|
||||
val ((pos, _type), ident) = tupleOfPosTypeIdent
|
||||
new Func(pos, _type, ident, params, body)
|
||||
}
|
||||
override def labels: List[String] = List("function declaration")
|
||||
}
|
||||
object Param extends ParserBridge3[(Int, Int), Type, Ident, Param]
|
||||
|
||||
object Skip extends ParserBridge2[(Int, Int), Any, Skip] {
|
||||
override def apply(pos: (Int, Int), _x: Any): Skip = Skip(pos)
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object Declare extends ParserBridge4[(Int, Int), Type, Ident, Rhs, Declare] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object Assign extends ParserBridge3[(Int, Int), Lhs, Rhs, Assign] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object Read extends ParserBridge2[(Int, Int), Lhs, Read] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object Free extends ParserBridge2[(Int, Int), Expr, Free] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object Return extends ParserBridge2[(Int, Int), Expr, Return] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object Exit extends ParserBridge2[(Int, Int), Expr, Exit] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object Print extends ParserBridge2[(Int, Int), Expr, Print] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object Println extends ParserBridge2[(Int, Int), Expr, Println] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object If extends ParserBridge4[(Int, Int), Expr, List[Stmt], List[Stmt], If] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object While extends ParserBridge3[(Int, Int), Expr, List[Stmt], While] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
object Block extends ParserBridge2[(Int, Int), List[Stmt], Block] {
|
||||
override def labels: List[String] = List("statement")
|
||||
}
|
||||
|
||||
object ArrayElem extends ParserBridge3[((Int, Int), Ident), Expr, List[Expr], ArrayElem] {
|
||||
def apply(tupleOfPosIdent: ((Int, Int), Ident), head: Expr, rest: List[Expr]): ArrayElem = {
|
||||
val (pos, ident) = tupleOfPosIdent
|
||||
|
||||
// Prepend the rest with the head
|
||||
new ArrayElem(pos, ident, head +: rest)
|
||||
}
|
||||
}
|
||||
object Ident extends ParserBridge2[(Int, Int), String, Ident] {
|
||||
override def labels: List[String] = List("identifier")
|
||||
}
|
||||
object IntLiter extends ParserBridge2[(Int, Int), BigInt, IntLiter] {
|
||||
override def labels: List[String] = List("integer literal")
|
||||
}
|
||||
object BoolLiter extends ParserBridge2[(Int, Int), Boolean, BoolLiter] {
|
||||
override def labels: List[String] = List("boolean literal")
|
||||
}
|
||||
object CharLiter extends ParserBridge2[(Int, Int), Char, CharLiter] {
|
||||
override def labels: List[String] = List("character literal")
|
||||
}
|
||||
object StrLiter extends ParserBridge2[(Int, Int), String, StrLiter] {
|
||||
override def labels: List[String] = List("string literal")
|
||||
}
|
||||
object PairLiter extends ParserBridge2[(Int, Int), Any, PairLiter] {
|
||||
override def apply(pos: (Int, Int), _x: Any): PairLiter = PairLiter(pos)
|
||||
override def labels: List[String] = List("pair literal")
|
||||
}
|
||||
|
||||
object ArrayLiter extends ParserBridge2[(Int, Int), List[Expr], ArrayLiter] {
|
||||
override def labels: List[String] = List("array literal")
|
||||
}
|
||||
object NewPair extends ParserBridge3[(Int, Int), Expr, Expr, NewPair]
|
||||
object Call extends ParserBridge3[(Int, Int), Ident, List[Expr], Call] {
|
||||
override def labels: List[String] = List("function call")
|
||||
}
|
||||
|
||||
object Not extends ParserBridge1[Expr, Not] {
|
||||
override def labels: List[String] = List("expression")
|
||||
}
|
||||
object Neg extends ParserBridge1[Expr, Neg] {
|
||||
override def labels: List[String] = List("expression")
|
||||
}
|
||||
object Len extends ParserBridge1[Expr, Len] {
|
||||
override def labels: List[String] = List("expression")
|
||||
}
|
||||
object Ord extends ParserBridge1[Expr, Ord] {
|
||||
override def labels: List[String] = List("expression")
|
||||
}
|
||||
object Chr extends ParserBridge1[Expr, Chr] {
|
||||
override def labels: List[String] = List("expression")
|
||||
}
|
||||
object IfExpr extends ParserBridge4[(Int, Int), Expr, Expr, Expr, IfExpr] {
|
||||
override def labels: List[String] = List("expression")
|
||||
}
|
||||
object Mul extends ParserBridge2[Expr, Expr, Mul] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
object Div extends ParserBridge2[Expr, Expr, Div] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
object Mod extends ParserBridge2[Expr, Expr, Mod] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
|
||||
object Add extends ParserBridge2[Expr, Expr, Add] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
object Sub extends ParserBridge2[Expr, Expr, Sub] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
|
||||
object Gt extends ParserBridge2[Expr, Expr, Gt] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
object Geq extends ParserBridge2[Expr, Expr, Geq] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
object Lt extends ParserBridge2[Expr, Expr, Lt] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
object Leq extends ParserBridge2[Expr, Expr, Leq] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
|
||||
object Eq extends ParserBridge2[Expr, Expr, Eq] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
object Neq extends ParserBridge2[Expr, Expr, Neq] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
|
||||
object And extends ParserBridge2[Expr, Expr, And] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
object Or extends ParserBridge2[Expr, Expr, Or] {
|
||||
override def labels: List[String] = List("binary operator")
|
||||
}
|
||||
|
||||
object PairType extends ParserBridge2[PairElemType, PairElemType, PairType]
|
||||
object PairElem extends ParserBridge2[((Int, Int), PairSelector), Lhs, PairElem] {
|
||||
def apply(tupleOfPosSelector: ((Int, Int), PairSelector), lhs: Lhs): PairElem = {
|
||||
val (pos, selector) = tupleOfPosSelector
|
||||
new PairElem(pos, selector, lhs)
|
||||
}
|
||||
}
|
||||
object ArrayType extends ParserBridge1[Type, ArrayType]
|
||||
}
|
||||
40
src/main/wacc/frontend/syntax/imports.scala
Normal file
40
src/main/wacc/frontend/syntax/imports.scala
Normal file
@ -0,0 +1,40 @@
|
||||
package wacc.frontend.syntax
|
||||
|
||||
import ast._
|
||||
import scala.concurrent._
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
|
||||
class ImportHandler {
|
||||
// given a WProgram, add all the functions from the imports to the program
|
||||
def addImports(program: WProgram, optimiseFlag: Boolean = false): WProgram = {
|
||||
// If there are less than 4 imports, we can just process them sequentially, as the overhead of parallelism is not worth it
|
||||
// This will be evaluated at each level, so nested imports will be processed in parallel if they exceed 3 imports, even if
|
||||
// the top level does not
|
||||
val minInputs = 4
|
||||
val importFuncs = if (!optimiseFlag || program.imports.length < minInputs) {
|
||||
program.imports.map(_.path).map(addProgram).map(_.funcs).flatten
|
||||
} else {
|
||||
// Convert imports into a list of futures to be executed in parallel
|
||||
val futures = program.imports.map(_.path).map { path =>
|
||||
Future(addProgram(path))
|
||||
}
|
||||
|
||||
// Wait for all Futures to complete
|
||||
Await.result(Future.sequence(futures), scala.concurrent.duration.Duration.Inf)
|
||||
.flatMap(_.funcs)
|
||||
}
|
||||
|
||||
program.copy(funcs = program.funcs ++ importFuncs)
|
||||
}
|
||||
|
||||
// given a file path, parse the file and return the WProgram
|
||||
def addProgram(path: String): WProgram = {
|
||||
parser.parseFile(path) match {
|
||||
case parsley.Failure(msg) => {
|
||||
println(s"\nSyntax Error $msg in file $path")
|
||||
sys.exit(SYNTAX_ERROR_CODE)
|
||||
}
|
||||
case parsley.Success(progFromSyntax: WProgram) => addImports(progFromSyntax)
|
||||
}
|
||||
}
|
||||
}
|
||||
69
src/main/wacc/frontend/syntax/lexer.scala
Normal file
69
src/main/wacc/frontend/syntax/lexer.scala
Normal file
@ -0,0 +1,69 @@
|
||||
package wacc.frontend.syntax
|
||||
|
||||
import parsley.Parsley
|
||||
import parsley.token.{Lexer, Basic}
|
||||
import parsley.token.descriptions._
|
||||
import parsley.token.errors._
|
||||
import parsley.token.Unicode
|
||||
|
||||
def isEnglishLetter(c: Char): Boolean =
|
||||
c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
|
||||
|
||||
def isEnglishLetterOrDigit(c: Char): Boolean =
|
||||
isEnglishLetter(c) || c.isDigit
|
||||
|
||||
val errorConfig = new ErrorConfig {
|
||||
override def labelSymbol: Map[String, LabelWithExplainConfig] = Map(
|
||||
"[" -> Label("array index"),
|
||||
"=" -> Label("assignment")
|
||||
)
|
||||
}
|
||||
object lexer {
|
||||
val excludedChars = Set('\"', '\'', '\\')
|
||||
val keys = Set(
|
||||
"begin", "end", "is", "skip", "read", "free", "return", "exit",
|
||||
"print", "println", "if", "then", "else", "fi", "while", "do", "done",
|
||||
"fst", "snd", "newpair", "call", "int", "bool", "char", "string",
|
||||
"pair", "true", "false", "null", "len", "ord", "chr"
|
||||
)
|
||||
|
||||
val ops = Set(
|
||||
"=", "!", "-", "*", "/", "%", "+", "-",
|
||||
">", ">=", "<", "<=", "==", "!=", "&&", "||"
|
||||
)
|
||||
|
||||
private val desc = LexicalDesc.plain.copy(
|
||||
nameDesc = NameDesc.plain.copy(
|
||||
identifierStart = Basic(char => isEnglishLetter(char) || char == '_'),
|
||||
identifierLetter = Basic(char => isEnglishLetterOrDigit(char) || char == '_')
|
||||
),
|
||||
symbolDesc = SymbolDesc.plain.copy(
|
||||
hardKeywords = keys,
|
||||
hardOperators = ops,
|
||||
),
|
||||
textDesc = TextDesc.plain.copy(
|
||||
escapeSequences = EscapeDesc.plain.copy(
|
||||
literals = Set('\\', '\'', '\"'),
|
||||
mapping = Map("0" -> 0x00, "b" -> 0x08, "t" -> 0x09, "n" -> 0x0a, "f" -> 0x0c,
|
||||
"r" -> 0x0d)
|
||||
),
|
||||
graphicCharacter = Unicode(char => char >= ' '.toInt && !excludedChars.contains(char.toChar)
|
||||
)
|
||||
),
|
||||
spaceDesc = SpaceDesc.plain.copy(
|
||||
lineCommentStart = "#"
|
||||
),
|
||||
|
||||
)
|
||||
|
||||
private val lexer = new Lexer(desc, errorConfig)
|
||||
|
||||
// Define the primitive atoms of the grammar (done in terms of Scala types)
|
||||
val integer = lexer.lexeme.integer.decimal32[BigInt]
|
||||
val character = lexer.lexeme.character.ascii
|
||||
val string = lexer.lexeme.string.ascii
|
||||
val implicits = lexer.lexeme.symbol.implicits
|
||||
val identifier = lexer.lexeme.names.identifier
|
||||
|
||||
def fully[A](p: Parsley[A]): Parsley[A] = lexer.fully(p)
|
||||
}
|
||||
259
src/main/wacc/frontend/syntax/parser.scala
Normal file
259
src/main/wacc/frontend/syntax/parser.scala
Normal file
@ -0,0 +1,259 @@
|
||||
package wacc.frontend.syntax
|
||||
|
||||
import ast._
|
||||
import lexer._
|
||||
import lexer.implicits.implicitSymbol
|
||||
import parsley.{Parsley, Result}
|
||||
import parsley.combinator.{sepBy, sepBy1, ifS}
|
||||
import parsley.errors.combinator._
|
||||
import parsley.expr.{precedence, Ops, Prefix, InfixL, InfixN, InfixR}
|
||||
import parsley.position.pos
|
||||
import Parsley.{many, atomic, pure, empty}
|
||||
import scala.util.{Success, Failure}
|
||||
import java.io.File
|
||||
|
||||
object parser {
|
||||
def parseFile(input: String): Result[String, WProgram] = parser.parseFile(new File(input)) match {
|
||||
case Success(res) => res
|
||||
case Failure(_) => parsley.Failure("File not recognised")
|
||||
}
|
||||
def parse(input: String): Result[String, WProgram] = parser.parse(input)
|
||||
private val parser = fully(prog)
|
||||
|
||||
// --------------------------
|
||||
// PROGRAM
|
||||
// --------------------------
|
||||
|
||||
// prog ::= "begin" funcs* stmts "end"
|
||||
private lazy val prog: Parsley[WProgram] =
|
||||
WProgram(pos, many(_import), "begin" ~> many(func), stmts <~ "end")
|
||||
|
||||
// --------------------------
|
||||
// IMPORTS
|
||||
// --------------------------
|
||||
|
||||
// import ::= "import" string
|
||||
private lazy val _import: Parsley[Import] = Import(pos, "import" ~> string)
|
||||
|
||||
// --------------------------
|
||||
// FUNCTIONS
|
||||
// --------------------------
|
||||
|
||||
// func ::= <type> <ident> "(" paramList? ")" "is" stmts "end"
|
||||
private lazy val func: Parsley[Func] =
|
||||
Func(atomic(pos <~> _type <~> ident <~ "("), paramList <~ ")", "is" ~> stmts.filter(noReturns).explain(returnExplain) <~ "end")
|
||||
|
||||
// paramList ::= param ("," param)*
|
||||
private lazy val paramList: Parsley[List[Param]] = sepBy(param, ",")
|
||||
|
||||
// param ::= <type> <ident>
|
||||
private lazy val param: Parsley[Param] = Param(pos, _type, ident)
|
||||
|
||||
// predicate which returns true if a body returns/exits and false otherwise
|
||||
private def noReturns(stmts: List[Stmt]): Boolean = stmts.lastOption match {
|
||||
case Some(Return(_, _)) => true
|
||||
case Some(Exit(_, _)) => true
|
||||
case Some(While(_, _, body)) => noReturns(body)
|
||||
case Some(If(_, _, thenStmts, elseStmts)) => noReturns(thenStmts) &&
|
||||
noReturns(elseStmts)
|
||||
case Some(Block(_, stats)) => noReturns(stats)
|
||||
case _ => false
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// IDENTIFIERS
|
||||
// --------------------------
|
||||
|
||||
private lazy val ident: Parsley[Ident] = Ident(pos, identifier)
|
||||
private lazy val intLiter: Parsley[IntLiter] = IntLiter(pos, integer)
|
||||
private lazy val boolLiter: Parsley[BoolLiter] =
|
||||
BoolLiter(pos, "true" ~> pure(true) | "false" ~> pure(false))
|
||||
private lazy val charLiter: Parsley[CharLiter] = CharLiter(pos, character)
|
||||
private lazy val strLiter: Parsley[StrLiter] = StrLiter(pos, string)
|
||||
private lazy val pairLiter: Parsley[PairLiter] = PairLiter(pos, "null")
|
||||
private lazy val ifExpr: Parsley[IfExpr] =
|
||||
IfExpr(pos, "if" ~> expr, "then" ~> expr, ("else"| _semiCheck) ~> expr <~ "fi")
|
||||
|
||||
|
||||
// --------------------------
|
||||
// TYPES
|
||||
// --------------------------
|
||||
|
||||
// simpleType ::= baseType | pairType | "(" _type ")"
|
||||
// created to break the left recursion in type
|
||||
private lazy val simpleType: Parsley[Type] =
|
||||
baseType
|
||||
| pairType
|
||||
| ("(" ~> _type <~ ")")
|
||||
|
||||
// baseType ::= "int" | "bool" | "char" | "string"
|
||||
private lazy val baseType: Parsley[Type] =
|
||||
("int" as IntType)
|
||||
| ("bool" as BoolType)
|
||||
| ("char" as CharType)
|
||||
| ("string" as StrType)
|
||||
|
||||
// pairType ::= "pair" "(" pairElemType "," pairElemType ")"
|
||||
private lazy val pairType: Parsley[Type] =
|
||||
PairType("pair" ~> ("(" ~> pairElemType), "," ~> pairElemType <~ ")")
|
||||
|
||||
// arraySuffix parses "[]" and returns a function wrapping a type in ArrayType
|
||||
private lazy val arraySuffix: Parsley[Type => Type] =
|
||||
ArrayType from ("[" ~> "]")
|
||||
|
||||
// _type ::= simpleType (arraySuffix)*
|
||||
private lazy val _type: Parsley[Type] =
|
||||
simpleType.flatMap {t =>
|
||||
many(arraySuffix).map(fs => fs.foldLeft(t)((acc, f) => f(acc)))
|
||||
}
|
||||
|
||||
// pairElemType: a base (or array) type for pair elements, or the literal "pair" for erasure
|
||||
private lazy val pairElemType: Parsley[PairElemType] =
|
||||
(baseType.flatMap {t =>
|
||||
many(arraySuffix).map(fs => fs.foldLeft(t)((acc, f) => f(acc)))
|
||||
})
|
||||
| ("pair" as PairPlaceholder)
|
||||
|
||||
// --------------------------
|
||||
// EXPRESSIONS
|
||||
// --------------------------
|
||||
|
||||
// Expression parser with precedence handling
|
||||
private lazy val expr: Parsley[Expr] =
|
||||
precedence(
|
||||
(intLiter
|
||||
| boolLiter
|
||||
| charLiter
|
||||
| strLiter
|
||||
| pairLiter
|
||||
| arrayElem
|
||||
| ident
|
||||
| ifExpr
|
||||
| ("(" ~> expr <~ ")")).label("expression").explain(expressionExplain))
|
||||
(
|
||||
// If the token is an integer, return the empty combinator to escape this case so that
|
||||
// it will be handled elsewhere
|
||||
// Otherwise, wrap the expression (which is not a pure integer) in the Neg object
|
||||
Ops(Prefix)(atomic(ifS(atomic(integer.hide) ~> pure(true)| pure(false),
|
||||
empty, Neg from "-"))),
|
||||
Ops(Prefix)(Not from "!"),
|
||||
Ops(Prefix)(Len from "len"),
|
||||
Ops(Prefix)(Ord from "ord"),
|
||||
Ops(Prefix)(Chr from "chr"),
|
||||
|
||||
Ops(InfixL)(Mul from "*"),
|
||||
Ops(InfixL)(Mod from "%"),
|
||||
Ops(InfixL)(Div from "/"),
|
||||
|
||||
Ops(InfixL)(Add from "+"),
|
||||
Ops(InfixL)(Sub from "-"),
|
||||
|
||||
Ops(InfixN)(Gt from ">"),
|
||||
Ops(InfixN)(Geq from ">="),
|
||||
Ops(InfixN)(Lt from "<"),
|
||||
Ops(InfixN)(Leq from "<="),
|
||||
|
||||
Ops(InfixN)(Eq from "=="),
|
||||
Ops(InfixN)(Neq from "!="),
|
||||
|
||||
Ops(InfixR)(And from "&&"),
|
||||
Ops(InfixR)(Or from "||")
|
||||
)
|
||||
|
||||
// --------------------------
|
||||
// STATEMENTS
|
||||
// --------------------------
|
||||
|
||||
private lazy val stmt: Parsley[Stmt] =
|
||||
skipStmt
|
||||
| declareStmt
|
||||
| assignStmt
|
||||
| readStmt
|
||||
| freeStmt
|
||||
| returnStmt
|
||||
| exitStmt
|
||||
| printStmt
|
||||
| printlnStmt
|
||||
| ifStmt
|
||||
| whileStmt
|
||||
| blockStmt
|
||||
|
||||
// stmts ::= stmt (";" stmt)*
|
||||
private lazy val stmts: Parsley[List[Stmt]] = sepBy1(stmt, ";")
|
||||
|
||||
// "skip" statement.
|
||||
private lazy val skipStmt: Parsley[Skip] = Skip(pos, "skip")
|
||||
|
||||
// Declaration: <type> <ident> "=" rVal
|
||||
private lazy val declareStmt: Parsley[Declare] = Declare(pos, _type, ident <~ "=", rVal)
|
||||
|
||||
// Assignment: lVal "=" rVal
|
||||
private lazy val assignStmt: Parsley[Assign] = Assign(pos, lVal <~ "=", rVal)
|
||||
|
||||
// Read statement: "read" lVal
|
||||
private lazy val readStmt: Parsley[Read] = Read(pos, "read" ~> lVal)
|
||||
|
||||
// Free statement: "free" expr
|
||||
private lazy val freeStmt: Parsley[Free] = Free(pos, "free" ~> expr)
|
||||
|
||||
// Return statement: "return" expr
|
||||
private lazy val returnStmt: Parsley[Return] = Return(pos, "return" ~> expr)
|
||||
|
||||
// Exit statement: "exit" expr
|
||||
private lazy val exitStmt: Parsley[Exit] = Exit(pos, "exit" ~> expr)
|
||||
|
||||
// Print statement: "print" expr
|
||||
private lazy val printStmt: Parsley[Print] = Print(pos, "print" ~> expr)
|
||||
|
||||
// Println statement: "println" expr
|
||||
private lazy val printlnStmt: Parsley[Println] = Println(pos, "println" ~> expr)
|
||||
|
||||
// If statement: "if" expr "then" stmts "else" stmts "fi"
|
||||
private lazy val whileStmt: Parsley[While] = While(pos, "while" ~> expr, "do" ~> stmts <~ "done")
|
||||
|
||||
// While statement: "while" expr "do" stmts "done"
|
||||
private lazy val ifStmt: Parsley[If] =
|
||||
If(pos, "if" ~> expr, "then" ~> stmts, ("else"| _semiCheck) ~> stmts <~ "fi")
|
||||
|
||||
// Block statement: "begin" stmts "end"
|
||||
private lazy val blockStmt: Parsley[Block] = Block(pos, "begin" ~> stmts <~ "end")
|
||||
|
||||
// --------------------------
|
||||
// L-VALUES and R-VALUES
|
||||
// --------------------------
|
||||
|
||||
// lVal ::= arrayElem | pairElem | lident
|
||||
private lazy val lVal: Parsley[Lhs] =
|
||||
arrayElem
|
||||
| pairElem
|
||||
| ident
|
||||
|
||||
// arrayElem ::= lident ("[" expr "]")+
|
||||
private lazy val arrayElem: Parsley[ArrayElem] = ArrayElem(atomic(pos <~> ident <~ "["), expr <~ "]", many("[" ~> expr <~ "]"))
|
||||
|
||||
// pairElem ::= ("fst" | "snd") lVal
|
||||
private lazy val pairElem: Parsley[PairElem] =
|
||||
PairElem(atomic(pos <~> (("fst" as Fst) | ("snd" as Snd))), lVal)
|
||||
|
||||
// rVal ::= expr | arrayLiter | newPair | pairElem | call
|
||||
private lazy val rVal: Parsley[Rhs] =
|
||||
expr
|
||||
| arrayLiter
|
||||
| newPair
|
||||
| pairElem
|
||||
| call
|
||||
|
||||
// arrayLiter ::= "[" exprs "]"
|
||||
private lazy val arrayLiter: Parsley[ArrayLiter] =
|
||||
ArrayLiter(pos, "[" ~> exprs <~ "]")
|
||||
|
||||
// newPair ::= "newpair" "(" expr "," expr ")"
|
||||
private lazy val newPair: Parsley[NewPair] =
|
||||
NewPair(pos, "newpair" ~> ("(" ~> expr), "," ~> expr <~ ")")
|
||||
|
||||
// call ::= "call" lident "(" exprs ")"
|
||||
private lazy val call: Parsley[Call] = Call(pos, "call" ~> ident, "(" ~> exprs <~ ")")
|
||||
|
||||
// exprs ::= (expr ("," expr)*)?
|
||||
private lazy val exprs: Parsley[List[Expr]] = sepBy(expr, ",")
|
||||
}
|
||||
17
src/main/wacc/frontend/syntax/syntaxErrors.scala
Normal file
17
src/main/wacc/frontend/syntax/syntaxErrors.scala
Normal file
@ -0,0 +1,17 @@
|
||||
package wacc.frontend.syntax
|
||||
|
||||
import parsley.character.char
|
||||
import parsley.errors.patterns.VerifiedErrors
|
||||
|
||||
val SYNTAX_ERROR_CODE = 100
|
||||
|
||||
val _semiCheck = char(';').verifiedExplain("semi-colons cannot be written between `if` and `else`")
|
||||
|
||||
// Defining explanations for constructs that yield the same error explanation message
|
||||
val expressionExplain = "expressions may start with integer, string, character or boolean\n" +
|
||||
" literals; identifiers; unary operators; null; or parentheses.\n" +
|
||||
" In addition, expressions may contain array indexing operations; and\n" +
|
||||
" comparison, logical, and arithmetic operators"
|
||||
|
||||
val bodyEndExplain = "all programs must be enclosed in begin ... end"
|
||||
val returnExplain = "Missing a return on all exit paths"
|
||||
114
src/test/wacc/codeGen.scala
Normal file
114
src/test/wacc/codeGen.scala
Normal file
@ -0,0 +1,114 @@
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
|
||||
class codeGenTest extends CodeGeneratorTestSuite {
|
||||
|
||||
/* Wrapper for uniform testing */
|
||||
def test(input: os.Path): (Boolean, String) = {
|
||||
super.test(input, None, None)
|
||||
}
|
||||
|
||||
// IGNORE SET: Contains files/directories to ignore in testing
|
||||
val ignoreSet: Set[os.Path] = Set()
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing programs with syntax errors
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING CODE GEN IN syntaxErr"
|
||||
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxArray))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxBasic))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxExpressions))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxFunction))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxIf))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxLiteral))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxPairs))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxPrint))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxSequence))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxVariables))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSyntaxWhile))
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing programs with semantic errors
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING CODE GEN IN semanticErr"
|
||||
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticArray))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticExit))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticExpressions))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticFunction))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticIf))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticIO))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticMultiple))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticPairs))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticPrint))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticRead))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticScope))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticVariables))
|
||||
genTest("Code Generator", ignoreSet, test, "fail", os.list(invalidSemanticWhile))
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing valid programs
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/basic/exit"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validBasicExit))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/basic/skip"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validBasicSkip))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/expressions"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validExpression))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/function/nested_functions"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validFunctionNestFuns))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/function/simple_functions"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validFunctionSimpFuns))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/function/overload_functions"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validFunctionOverFuns))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/if"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validIf))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/IO/print"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validIOPrint))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/IO/read"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validIORead))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/IO/special"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validIOSpecial))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/pairs"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validPair))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/runtimeErr/arrayOutOfBounds"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validRTEArrOOB))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/runtimeErr/badChar"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validRTEBadChar))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/runtimeErr/divideByZero"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validRTEDivByZero))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/runtimeErr/integerOverflow"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validRTEIntOverflow))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/runtimeErr/nullDereference"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validRTENullDereference))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/scope"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validScope))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/sequence"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validSequence))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/variables"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validVariables))
|
||||
|
||||
behavior of "TESTING CODE GEN IN valid/while"
|
||||
genTest("Code Generator", ignoreSet, test, "succeed", os.list(validWhile))
|
||||
}
|
||||
194
src/test/wacc/codeGenTestConfig.scala
Normal file
194
src/test/wacc/codeGenTestConfig.scala
Normal file
@ -0,0 +1,194 @@
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import parsley._
|
||||
import scala.util.{Try, Success, Failure}
|
||||
import wacc.backend.CodeGenerator
|
||||
import wacc.extension.Peephole
|
||||
import wacc.frontend.semantic.analyse
|
||||
import wacc.frontend.syntax.ast.WProgram
|
||||
import wacc.frontend.syntax.ImportHandler
|
||||
import java.time.Instant
|
||||
|
||||
// This is the trait for all code generating tests (CodeGen, Imports, Peephole)
|
||||
trait CodeGeneratorTestSuite extends ParserTestSuite {
|
||||
|
||||
/**
|
||||
* Returns the success of generating a wacc file
|
||||
*
|
||||
* @param input Path to wacc file
|
||||
* @return either (true, "") on a successful gen or (false, msg) on an unsuccessful gen
|
||||
*/
|
||||
def test(input: os.Path, peepholeOpt: Option[Peephole], importHandlerOpt: Option[ImportHandler], printParallelTests: Boolean = false): (Boolean, String) = {
|
||||
parseFile(input) match {
|
||||
case parsley.Success(prog: WProgram) => {
|
||||
|
||||
// Adds the program to be analysed based on whether the import handler exists or not
|
||||
val progToAnalyse = importHandlerOpt match {
|
||||
case Some(importHandler) =>
|
||||
if (printParallelTests) {
|
||||
var average = 0L
|
||||
val testNum = 100
|
||||
println(s"Starting Non-Parallel import processing for ${input.last} over $testNum tests")
|
||||
for (i <- 1 to testNum) {
|
||||
val start = Instant.now()
|
||||
importHandler.addImports(prog)
|
||||
val end = Instant.now()
|
||||
average += java.time.Duration.between(start, end).toNanos
|
||||
}
|
||||
println(s"Non-Parallel processing took: ${average/testNum} ns on average")
|
||||
average = 0L
|
||||
println(s"Starting Parallel import processing for ${input.last}")
|
||||
for (i <- 1 to testNum) {
|
||||
val start = Instant.now()
|
||||
importHandler.addImports(prog, optimiseFlag = true)
|
||||
val end = Instant.now()
|
||||
average += java.time.Duration.between(start, end).toNanos
|
||||
}
|
||||
println(s"Parallel processing took: ${average/testNum} ns")
|
||||
}
|
||||
|
||||
importHandler.addImports(prog, optimiseFlag = printParallelTests)
|
||||
case None => prog
|
||||
}
|
||||
|
||||
// Analyse the program and perform further checks after
|
||||
analyse(progToAnalyse) match {
|
||||
case Left(res) =>
|
||||
val fileName = input.last
|
||||
val outputFile = s"${fileName.replace(".wacc", ".s")}"
|
||||
val outputPath = os.pwd / "src" / "test" / "wacc" / "waccOutputs" / outputFile
|
||||
|
||||
// Make the outputPath's folder if it doesn't already exist
|
||||
os.makeDir.all(outputPath / os.up)
|
||||
val (prog, fEnv, mEnv) = res
|
||||
|
||||
// Create the code generator and peephole instances
|
||||
val codeGenerator = new CodeGenerator()
|
||||
|
||||
// Generate the asm without any possible peephole optimization
|
||||
val unOptimizedAsm = codeGenerator.generate(prog, fEnv, mEnv)
|
||||
|
||||
val asmCode = peepholeOpt match {
|
||||
case Some(peephole) => peephole.optimize(unOptimizedAsm).map(_.emit).mkString("\n")
|
||||
case None => unOptimizedAsm.map(_.emit).mkString("\n")
|
||||
}
|
||||
|
||||
// Writes to the folder location given that it doesn't exist or contains a different asm
|
||||
if (!os.exists(outputPath) || os.read(outputPath) != asmCode) {
|
||||
os.write.over(outputPath, asmCode)
|
||||
}
|
||||
|
||||
// Run checkAsm, then delete the folder containing the outputs regardless of the outcome
|
||||
val result =
|
||||
try {
|
||||
checkAsm(input, outputPath)
|
||||
} finally {
|
||||
os.remove.all(outputPath / os.up)
|
||||
}
|
||||
result
|
||||
|
||||
case Right(errs) =>
|
||||
val fileName = input.last
|
||||
(false, errs.map(err => err.getMessage(fileName, input.toString)).mkString("\n"))
|
||||
}
|
||||
}
|
||||
case parsley.Failure(msg) => (false, msg)
|
||||
}
|
||||
}
|
||||
|
||||
def checkAsm(waccPath: os.Path, asmPath: os.Path): (Boolean, String) = {
|
||||
// Read the WACC file lines
|
||||
val lines = os.read.lines(waccPath)
|
||||
|
||||
// Extract the expected exit code from a "# Exit:" block (default to 0 if not provided)
|
||||
val expectedExit: Int = {
|
||||
val maybeIndex = lines.indexWhere(_.trim == "# Exit:")
|
||||
if (maybeIndex != -1 && lines.size > maybeIndex + 1)
|
||||
Try(lines(maybeIndex + 1).stripPrefix("#").trim.toInt).getOrElse(0)
|
||||
else 0
|
||||
}
|
||||
|
||||
// Extract expected output lines from the "# Output:" block
|
||||
// We take all lines after "# Output:" until we reach a blank line
|
||||
val expectedOutput: List[String] = {
|
||||
val startIndex = lines.indexWhere(_.trim == "# Output:")
|
||||
if (startIndex != -1)
|
||||
lines.drop(startIndex + 1)
|
||||
.takeWhile(line => line.stripPrefix("#").trim.nonEmpty)
|
||||
.map(_.stripPrefix("#").trim)
|
||||
.toList
|
||||
else Nil
|
||||
}
|
||||
|
||||
// Extract input (if any) from a line starting with "# Input:"
|
||||
val maybeInput: Option[String] =
|
||||
lines.find(_.trim.startsWith("# Input:")).map(_.stripPrefix("# Input:").trim)
|
||||
|
||||
// Derive the assembly file name and the executable name
|
||||
val asmFileName = asmPath.last
|
||||
val executableName = asmFileName.stripSuffix(".s")
|
||||
|
||||
// Get the directory containing the assembly file (this will be the working directory)
|
||||
val asmDir = asmPath / os.up
|
||||
|
||||
// Build the run command using qemu-aarch64 properly
|
||||
val qemuString = s"qemu-aarch64 -L /usr/aarch64-linux-gnu/ ./$executableName"
|
||||
val runCommand = maybeInput match {
|
||||
case Some(input) =>
|
||||
// Escape any single quotes to avoid shell issues
|
||||
val safeInput = input.replace("'", "'\\''")
|
||||
s"echo \"$safeInput\" | $qemuString"
|
||||
case None =>
|
||||
s"echo '' | $qemuString"
|
||||
}
|
||||
|
||||
// Build the compile command using the cross compiler
|
||||
val compileCommand = s"aarch64-linux-gnu-gcc -o $executableName -z noexecstack -march=armv8-a $asmFileName"
|
||||
val execCommand = s"$compileCommand && $runCommand"
|
||||
|
||||
// Execute the command using the shell, setting cwd to the asm directory
|
||||
val resultTry = Try(os.proc("sh", "-c", execCommand)
|
||||
.call(cwd = asmDir, stdout = os.Pipe, stderr = os.Pipe, check = false))
|
||||
|
||||
resultTry match {
|
||||
case Failure(e) => (false, s"Exception while running command: ${e.getMessage}")
|
||||
case Success(result) =>
|
||||
val rawOutput = result.out.text().trim.linesIterator.toList
|
||||
val actualOutput = normaliseOutput(rawOutput)
|
||||
val actualExit = result.exitCode
|
||||
var errors = List.empty[String]
|
||||
if (actualExit != expectedExit) {
|
||||
errors ::= s"Expected exit code $expectedExit but got $actualExit."
|
||||
}
|
||||
if (actualOutput != expectedOutput) {
|
||||
errors ::= s"Expected output:\n${expectedOutput.mkString("\n")}\nBut got:\n${actualOutput.mkString("\n")}"
|
||||
}
|
||||
if (errors.isEmpty) {
|
||||
(true, "")
|
||||
} else {
|
||||
(false, errors.reverse.mkString("\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Normalises the output after executing generated code
|
||||
* replacing addresses with #addrs#
|
||||
* replacing runtime_errors with #runtime_error#
|
||||
*
|
||||
* Returns the normalised Output as a list of strings
|
||||
*/
|
||||
def normaliseOutput(rawOutput: List[String]): List[String] = {
|
||||
val addressPattern = "0x[0-9A-Fa-f]+".r
|
||||
val rtePattern = "^fatal error:.*$".r
|
||||
|
||||
val outputWithAddrs = rawOutput.map { line =>
|
||||
addressPattern.replaceAllIn(line, "#addrs#")
|
||||
}
|
||||
|
||||
val outputWithRTEs = outputWithAddrs.map { line =>
|
||||
rtePattern.replaceAllIn(line, "#runtime_error#")
|
||||
}
|
||||
|
||||
outputWithRTEs
|
||||
}
|
||||
|
||||
}
|
||||
121
src/test/wacc/imports.scala
Normal file
121
src/test/wacc/imports.scala
Normal file
@ -0,0 +1,121 @@
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import wacc.frontend.syntax.ImportHandler
|
||||
|
||||
class importsTest extends CodeGeneratorTestSuite {
|
||||
|
||||
/* Wrapper for uniform testing */
|
||||
def test(input: os.Path): (Boolean, String) = {
|
||||
|
||||
// Create the importHandler instance
|
||||
val importHandler = ImportHandler()
|
||||
super.test(input, None, Some(importHandler))
|
||||
}
|
||||
|
||||
// IGNORE SET: Contains files/directories to ignore in testing
|
||||
val ignoreSet: Set[os.Path] = Set()
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing programs with syntax errors
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN syntaxErr"
|
||||
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxArray))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxBasic))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxExpressions))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxFunction))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxIf))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxLiteral))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxPairs))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxPrint))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxSequence))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxVariables))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSyntaxWhile))
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing programs with semantic errors
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN semanticErr"
|
||||
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticArray))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticExit))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticExpressions))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticFunction))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticIf))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticIO))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticMultiple))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticPairs))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticPrint))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticRead))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticScope))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticVariables))
|
||||
genTest("Import Handler", ignoreSet, test, "fail", os.list(invalidSemanticWhile))
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing valid programs
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/basic/exit"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validBasicExit))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/basic/skip"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validBasicSkip))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/expressions"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validExpression))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/function/nested_functions"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validFunctionNestFuns))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/function/simple_functions"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validFunctionSimpFuns))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/function/overload_functions"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validFunctionOverFuns))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/function/imported_functions"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validFunctionImports))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/if"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validIf))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/IO/print"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validIOPrint))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/IO/read"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validIORead))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/IO/special"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validIOSpecial))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/pairs"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validPair))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/runtimeErr/arrayOutOfBounds"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validRTEArrOOB))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/runtimeErr/badChar"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validRTEBadChar))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/runtimeErr/divideByZero"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validRTEDivByZero))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/runtimeErr/integerOverflow"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validRTEIntOverflow))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/runtimeErr/nullDereference"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validRTENullDereference))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/scope"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validScope))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/sequence"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validSequence))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/variables"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validVariables))
|
||||
|
||||
behavior of "TESTING IMPORT HANDLER IN valid/while"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validWhile))
|
||||
}
|
||||
23
src/test/wacc/importsParallel.scala
Normal file
23
src/test/wacc/importsParallel.scala
Normal file
@ -0,0 +1,23 @@
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import wacc.frontend.syntax.ImportHandler
|
||||
|
||||
class importsParallelTest extends CodeGeneratorTestSuite {
|
||||
|
||||
/* Wrapper for uniform testing */
|
||||
def test(input: os.Path): (Boolean, String) = {
|
||||
|
||||
// Create the importHandler instance
|
||||
val importHandler = ImportHandler()
|
||||
super.test(input, None, Some(importHandler), printParallelTests = false)
|
||||
}
|
||||
|
||||
// IGNORE SET: Contains files/directories to ignore in testing
|
||||
val ignoreSet: Set[os.Path] = Set()
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing programs with syntax errors
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING PARALLEL IMPORT HANDLER IN valid/function/imported_functions"
|
||||
genTest("Import Handler", ignoreSet, test, "succeed", os.list(validFunctionImports))
|
||||
}
|
||||
171
src/test/wacc/parser.scala
Normal file
171
src/test/wacc/parser.scala
Normal file
@ -0,0 +1,171 @@
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import parsley.{Success, Failure}
|
||||
|
||||
class ParserTests extends ParserTestSuite {
|
||||
|
||||
/**
|
||||
* Returns the success of parsing a wacc file
|
||||
*
|
||||
* @param input Path to wacc file
|
||||
* @return either (true, "") on a successful parse or (false, msg) on an unsuccessful parse
|
||||
*/
|
||||
def test(input: os.Path): (Boolean, String) = {
|
||||
parseFile(input) match {
|
||||
case Success(_) => (true, "")
|
||||
case Failure(msg) => (false, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// IGNORE SET: Contains files/directories to ignore in testing
|
||||
val ignoreSet: Set[os.Path] = Set()
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing programs with syntax errors
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/array"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxArray))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/basic"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxBasic))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/expressions"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxExpressions))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/function"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxFunction))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/if"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxIf))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/literals"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxLiteral))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/pairs"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxPairs))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/print"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxPrint))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/sequence"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxSequence))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/variables"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxVariables))
|
||||
|
||||
behavior of "TESTING PARSER IN syntaxErr/while"
|
||||
genTest("Parser", ignoreSet, test, "fail", os.list(invalidSyntaxWhile))
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing syntatically valid programs
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING PARSER IN valid/advanced"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validAdvanced))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/array"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validArray))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/basic/exit"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validBasicExit))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/basic/skip"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validBasicSkip))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/expressions"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validExpression))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/function/nested_functions"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validFunctionNestFuns))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/function/simple_functions"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validFunctionSimpFuns))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/function/overload_functions"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validFunctionOverFuns))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/if"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validIf))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/IO/print"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validIOPrint))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/IO/read"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validIORead))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/IO/special"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validIOSpecial))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/pairs"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validPair))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/runtimeErr/arrayOutOfBounds"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validRTEArrOOB))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/runtimeErr/badChar"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validRTEBadChar))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/runtimeErr/divideByZero"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validRTEDivByZero))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/runtimeErr/integerOverflow"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validRTEIntOverflow))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/runtimeErr/nullDereference"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validRTENullDereference))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/scope"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validScope))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/sequence"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validSequence))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/variables"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validVariables))
|
||||
|
||||
behavior of "TESTING PARSER IN valid/while"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(validWhile))
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing syntatically valid programs (These are semantically invalid)
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/array"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticArray))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/exit"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticExit))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/expressions"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticExpressions))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/function"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticFunction))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/if"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticIf))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/IO"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticIO))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/multiple"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticMultiple))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/pairs"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticPairs))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/print"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticPrint))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/read"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticRead))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/scope"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticScope))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/variables"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticVariables))
|
||||
|
||||
behavior of "TESTING PARSER IN semanticErr/while"
|
||||
genTest("Parser", ignoreSet, test, "succeed", os.list(invalidSemanticWhile))
|
||||
}
|
||||
119
src/test/wacc/peephole.scala
Normal file
119
src/test/wacc/peephole.scala
Normal file
@ -0,0 +1,119 @@
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import wacc.extension.Peephole
|
||||
|
||||
// These tests are here to ensure that all previous tests still pass after the peephole optimization is applied
|
||||
class peepholeTests extends CodeGeneratorTestSuite {
|
||||
|
||||
/* Wrapper for uniform testing */
|
||||
def test(input: os.Path): (Boolean, String) = {
|
||||
|
||||
// Create the peephole instance
|
||||
val peephole = Peephole()
|
||||
super.test(input, Some(peephole), None)
|
||||
}
|
||||
|
||||
// IGNORE SET: Contains files/directories to ignore in testing
|
||||
val ignoreSet: Set[os.Path] = Set()
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing programs with syntax errors
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN syntaxErr"
|
||||
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxArray))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxBasic))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxExpressions))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxFunction))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxIf))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxLiteral))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxPairs))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxPrint))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxSequence))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxVariables))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSyntaxWhile))
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing programs with semantic errors
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN semanticErr"
|
||||
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticArray))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticExit))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticExpressions))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticFunction))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticIf))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticIO))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticMultiple))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticPairs))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticPrint))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticRead))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticScope))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticVariables))
|
||||
genTest("Peephole", ignoreSet, test, "fail", os.list(invalidSemanticWhile))
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing valid programs
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/basic/exit"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validBasicExit))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/basic/skip"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validBasicSkip))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/expressions"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validExpression))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/function/nested_functions"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validFunctionNestFuns))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/function/simple_functions"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validFunctionSimpFuns))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/function/overload_functions"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validFunctionOverFuns))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/if"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validIf))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/IO/print"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validIOPrint))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/IO/read"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validIORead))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/IO/special"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validIOSpecial))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/pairs"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validPair))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/arrayOutOfBounds"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTEArrOOB))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/badChar"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTEBadChar))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/divideByZero"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTEDivByZero))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/integerOverflow"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTEIntOverflow))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/runtimeErr/nullDereference"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validRTENullDereference))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/scope"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validScope))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/sequence"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validSequence))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/variables"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validVariables))
|
||||
|
||||
behavior of "TESTING PEEPHOLE IN valid/while"
|
||||
genTest("Peephole", ignoreSet, test, "succeed", os.list(validWhile))
|
||||
}
|
||||
143
src/test/wacc/semanticAnalyser.scala
Normal file
143
src/test/wacc/semanticAnalyser.scala
Normal file
@ -0,0 +1,143 @@
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import parsley.{Success, Failure}
|
||||
import wacc.frontend.syntax.ast.WProgram
|
||||
import wacc.frontend.semantic.analyse
|
||||
|
||||
class semanticAnalyserTests extends ParserTestSuite {
|
||||
|
||||
/**
|
||||
* Returns the success of parsing a wacc file
|
||||
*
|
||||
* @param input Path to wacc file
|
||||
* @return either (true, "") on a successful analysis or (false, msg) on an unsuccessful analysis
|
||||
*/
|
||||
def test(input: os.Path): (Boolean, String) = {
|
||||
parseFile(input) match {
|
||||
case Success(x: WProgram) => {
|
||||
analyse(x) match {
|
||||
case Left(prog) => (true, "")
|
||||
case Right(errs) =>
|
||||
val fileName = input.last
|
||||
(false, errs.map(err => err.getMessage(fileName, input.toString)).mkString("\n"))
|
||||
}
|
||||
}
|
||||
case Failure(msg) => (false, "OOPS. Looks like the parser failed :/")
|
||||
}
|
||||
}
|
||||
|
||||
// IGNORE SET: Contains files/directories to ignore in testing
|
||||
val ignoreSet: Set[os.Path] = Set()
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing programs with semantic errors
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/array"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticArray))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/exit"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticExit))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/expressions"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticExpressions))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/function"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticFunction))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/if"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticIf))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/IO"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticIO))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/multiple"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticMultiple))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/pairs"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticPairs))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/print"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticPrint))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/read"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticRead))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/scope"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticScope))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/variables"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticVariables))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN semanticErr/while"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "fail", os.list(invalidSemanticWhile))
|
||||
|
||||
// ----------------------------------------------------
|
||||
// Testing semantically valid programs
|
||||
// ----------------------------------------------------
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/advanced"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validAdvanced))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/array"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validArray))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/basic/exit"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validBasicExit))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/basic/skip"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validBasicSkip))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/expressions"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validExpression))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/function/nested_functions"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validFunctionNestFuns))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/function/simple_functions"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validFunctionSimpFuns))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/function/overload_functions"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validFunctionOverFuns))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/if"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validIf))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/IO/print"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validIOPrint))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/IO/read"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validIORead))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/IO/special"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validIOSpecial))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/pairs"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validPair))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/runtimeErr/arrayOutOfBounds"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validRTEArrOOB))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/runtimeErr/badChar"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validRTEBadChar))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/runtimeErr/divideByZero"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validRTEDivByZero))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/runtimeErr/integerOverflow"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validRTEIntOverflow))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/runtimeErr/nullDereference"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validRTENullDereference))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/scope"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validScope))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/sequence"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validSequence))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/variables"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validVariables))
|
||||
|
||||
behavior of "TESTING SEMANTIC ANALYSER IN valid/while"
|
||||
genTest("Semantic Analyser", ignoreSet, test, "succeed", os.list(validWhile))
|
||||
}
|
||||
131
src/test/wacc/testConfig.scala
Normal file
131
src/test/wacc/testConfig.scala
Normal file
@ -0,0 +1,131 @@
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import parsley.Result
|
||||
import wacc.frontend.syntax.{parser, ast}
|
||||
import wacc.frontend.syntax.ast._
|
||||
|
||||
// ----------------------------------------------------
|
||||
// SETUP WITH PATHS TO THE DIFFERENT DIRECTORIES
|
||||
// ----------------------------------------------------
|
||||
|
||||
val basePath = os.pwd / "src" / "test" / "wacc" / "waccPrograms"
|
||||
val invalidSyntaxPath = basePath / "syntaxErr"
|
||||
val invalidSemanticPath = basePath / "semanticErr"
|
||||
val validPath = basePath / "valid"
|
||||
|
||||
// Invalid Syntax test paths
|
||||
val invalidSyntaxArray = invalidSyntaxPath / "array"
|
||||
val invalidSyntaxBasic = invalidSyntaxPath / "basic"
|
||||
val invalidSyntaxExpressions = invalidSyntaxPath / "expressions"
|
||||
val invalidSyntaxFunction = invalidSyntaxPath / "function"
|
||||
val invalidSyntaxIf = invalidSyntaxPath / "if"
|
||||
val invalidSyntaxLiteral = invalidSyntaxPath / "literals"
|
||||
val invalidSyntaxPairs = invalidSyntaxPath / "pairs"
|
||||
val invalidSyntaxPrint = invalidSyntaxPath / "print"
|
||||
val invalidSyntaxSequence = invalidSyntaxPath / "sequence"
|
||||
val invalidSyntaxVariables = invalidSyntaxPath / "variables"
|
||||
val invalidSyntaxWhile = invalidSyntaxPath / "while"
|
||||
|
||||
// Invalid Semantic test paths
|
||||
val invalidSemanticArray = invalidSemanticPath / "array"
|
||||
val invalidSemanticExit = invalidSemanticPath / "exit"
|
||||
val invalidSemanticExpressions = invalidSemanticPath / "expressions"
|
||||
val invalidSemanticFunction = invalidSemanticPath / "function"
|
||||
val invalidSemanticIf = invalidSemanticPath / "if"
|
||||
val invalidSemanticIO = invalidSemanticPath / "IO"
|
||||
val invalidSemanticMultiple = invalidSemanticPath / "multiple"
|
||||
val invalidSemanticPairs = invalidSemanticPath / "pairs"
|
||||
val invalidSemanticPrint = invalidSemanticPath / "print"
|
||||
val invalidSemanticRead = invalidSemanticPath / "read"
|
||||
val invalidSemanticScope = invalidSemanticPath / "scope"
|
||||
val invalidSemanticVariables = invalidSemanticPath / "variables"
|
||||
val invalidSemanticWhile = invalidSemanticPath / "while"
|
||||
|
||||
// Valid test paths
|
||||
val validAdvanced = validPath / "advanced"
|
||||
val validArray = validPath / "array"
|
||||
val validBasicExit = validPath / "basic" / "exit"
|
||||
val validBasicSkip = validPath / "basic" / "skip"
|
||||
val validExpression = validPath / "expressions"
|
||||
val validFunctionNestFuns = validPath / "function" / "nested_functions"
|
||||
val validFunctionSimpFuns = validPath / "function" / "simple_functions"
|
||||
val validFunctionOverFuns = validPath / "function" / "overload_functions"
|
||||
val validFunctionImports = validPath / "function" / "imported_functions"
|
||||
val validIf = validPath / "if"
|
||||
val validIOSpecial = validPath / "IO" / "special"
|
||||
val validIOPrint = validPath / "IO" / "print"
|
||||
val validIORead = validPath / "IO" / "read"
|
||||
val validPair = validPath / "pairs"
|
||||
val validRTEArrOOB = validPath / "runtimeErr" / "arrayOutOfBounds"
|
||||
val validRTEBadChar = validPath / "runtimeErr" / "badChar"
|
||||
val validRTEDivByZero = validPath / "runtimeErr" / "divideByZero"
|
||||
val validRTEIntOverflow = validPath / "runtimeErr" / "integerOverflow"
|
||||
val validRTENullDereference = validPath / "runtimeErr" / "nullDereference"
|
||||
val validScope = validPath / "scope"
|
||||
val validSequence = validPath / "sequence"
|
||||
val validVariables = validPath / "variables"
|
||||
val validWhile = validPath / "while"
|
||||
|
||||
// ----------------------------------------------------
|
||||
// HELPER FUNCTIONS
|
||||
// ----------------------------------------------------
|
||||
|
||||
trait ParserTestSuite extends AnyFlatSpec {
|
||||
|
||||
/**
|
||||
* Returns parsed wacc file
|
||||
*
|
||||
* @param input Path to wacc file
|
||||
* @return parsed wacc program
|
||||
*/
|
||||
def parseFile(input: os.Path): Result[String, WProgram] = {
|
||||
val program = os.read(input)
|
||||
parser.parse(program)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates test cases for each file in `files`
|
||||
* If the file (or directory name) appears in `dontIgnoreSet` or doesn't appear in `ignoreSet`,
|
||||
* then the test will be ran.
|
||||
*
|
||||
* @param analyser
|
||||
* @param testFunc
|
||||
* @param res String indicating if the test should "fail" or "succeed"
|
||||
* @param files Sequence of files within the directory of interest
|
||||
*/
|
||||
def genTest(analyser: String,
|
||||
ignoreSet: Set[os.Path],
|
||||
testFunc: (os.Path) => (Boolean, String),
|
||||
res: String,
|
||||
files: IndexedSeq[os.Path]): Unit = {
|
||||
|
||||
for (file <- files) {
|
||||
val fileName = file.last
|
||||
|
||||
// --------------------------
|
||||
// IGNORING TESTS
|
||||
// --------------------------
|
||||
if (ignoreSet.exists(file.startsWith)) {
|
||||
ignore should s"$res with $fileName" in {
|
||||
// *** test is ignored ***
|
||||
}
|
||||
|
||||
// --------------------------
|
||||
// TESTING REMAINING TESTS
|
||||
// --------------------------
|
||||
} else {
|
||||
val (verdict, msg) = testFunc(file)
|
||||
val errorMessage = s"Expected $analyser to $res. $analyser returns:\n$msg"
|
||||
|
||||
if (res == "fail") {
|
||||
it should s"$res with $fileName" in {
|
||||
assert(verdict == false, errorMessage)
|
||||
}
|
||||
} else {
|
||||
it should s"$res with $fileName" in {
|
||||
assert(verdict == true, errorMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
15
src/test/wacc/waccPrograms/semanticErr/IO/readTypeErr.wacc
Normal file
15
src/test/wacc/waccPrograms/semanticErr/IO/readTypeErr.wacc
Normal file
@ -0,0 +1,15 @@
|
||||
# attempt to read into a boolean variable
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
bool b = true ;
|
||||
read b ;
|
||||
println b
|
||||
end
|
||||
@ -0,0 +1,16 @@
|
||||
# Attempting to access an array with an invalid complex expression
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int[] a = [1, 2];
|
||||
int b = a[1 - "horse"];
|
||||
int c = a[!false]
|
||||
end
|
||||
@ -0,0 +1,15 @@
|
||||
# Attempting to access an array with a non-integer index.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int[] a = [1, 2];
|
||||
int b = a["horse"]
|
||||
end
|
||||
@ -0,0 +1,20 @@
|
||||
# Indexing an array to get sub-arrays, but going too far.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int[] a = [1, 2];
|
||||
int[] b = [3, 4];
|
||||
|
||||
int[][] ab = [a, b];
|
||||
|
||||
int[] sameAsA = ab[0];
|
||||
int oops = sameAsA[0][1]
|
||||
end
|
||||
15
src/test/wacc/waccPrograms/semanticErr/array/badIndex.wacc
Normal file
15
src/test/wacc/waccPrograms/semanticErr/array/badIndex.wacc
Normal file
@ -0,0 +1,15 @@
|
||||
# Too much indexing!
|
||||
# Thanks to David Pan
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int[] a = [1];
|
||||
println a[1][2]
|
||||
end
|
||||
@ -0,0 +1,14 @@
|
||||
# Attempting to array-index an undefined identifier.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
char c = horse[2]
|
||||
end
|
||||
@ -0,0 +1,19 @@
|
||||
# Attempting to mix types in an array literal.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
char five() is
|
||||
begin return '5' end
|
||||
end
|
||||
|
||||
char f = call five();
|
||||
int[] a = [1, f]
|
||||
end
|
||||
@ -0,0 +1,15 @@
|
||||
# Arrays should not be covariant
|
||||
# Thanks to Nathaniel Burke
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
char[][] acs = [] ;
|
||||
string[] bad = acs
|
||||
end
|
||||
@ -0,0 +1,14 @@
|
||||
# It shouldn't be possible to index strings (wrong side of the type relaxation)
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
string str = "hello world";
|
||||
char x = str[0]
|
||||
end
|
||||
@ -0,0 +1,16 @@
|
||||
# Trying to assign the wrong array dimension to a variable, or non-matching pairs.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int[] a = [1, 2];
|
||||
int[][] aa = [a, a];
|
||||
int[] b = aa
|
||||
end
|
||||
@ -0,0 +1,15 @@
|
||||
# Accessing too many array indices.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int[] a = [1, 2];
|
||||
int oops = a[1][2][3]
|
||||
end
|
||||
@ -0,0 +1,14 @@
|
||||
# Array was given wrong type.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int[] should_be_int_array = ['a', 'b', 'c']
|
||||
end
|
||||
13
src/test/wacc/waccPrograms/semanticErr/exit/badCharExit.wacc
Normal file
13
src/test/wacc/waccPrograms/semanticErr/exit/badCharExit.wacc
Normal file
@ -0,0 +1,13 @@
|
||||
# tries to exit using a character
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
exit 'a'
|
||||
end
|
||||
14
src/test/wacc/waccPrograms/semanticErr/exit/exitNonInt.wacc
Normal file
14
src/test/wacc/waccPrograms/semanticErr/exit/exitNonInt.wacc
Normal file
@ -0,0 +1,14 @@
|
||||
# exit with non-int - this should be an invalid program!
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
char x = 'f' ;
|
||||
exit x
|
||||
end
|
||||
@ -0,0 +1,14 @@
|
||||
# trying to return from the main program
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
return 42 ;
|
||||
println "should not get here"
|
||||
end
|
||||
@ -0,0 +1,26 @@
|
||||
# Returning from the main body is forbidden.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
while true do
|
||||
return 3
|
||||
done;
|
||||
|
||||
if true then
|
||||
return 4
|
||||
else
|
||||
return 5
|
||||
fi;
|
||||
|
||||
begin
|
||||
return 6
|
||||
end
|
||||
end
|
||||
@ -0,0 +1,13 @@
|
||||
# expresission type mismatch int->bool
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
bool b = 1 || 1
|
||||
end
|
||||
@ -0,0 +1,13 @@
|
||||
# expresission type mismatch bool->int
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int b = 15 + 6 || 19
|
||||
end
|
||||
@ -0,0 +1,13 @@
|
||||
# expresission type mismatch bool->int
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int x = true + false
|
||||
end
|
||||
@ -0,0 +1,17 @@
|
||||
# evaluating less-than on references
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
pair(int,int) x = newpair(1,2) ;
|
||||
println x;
|
||||
pair(int,int) y = newpair(2,3) ;
|
||||
println y;
|
||||
println x < y
|
||||
end
|
||||
@ -0,0 +1,13 @@
|
||||
# expresission type mismatch bool->int
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int b = 1 + 2 + true + 4 + 5
|
||||
end
|
||||
@ -0,0 +1,15 @@
|
||||
# program performs boolean operations on arrays
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int[] x = [1,2];
|
||||
int[] y = [3,4];
|
||||
println x > y
|
||||
end
|
||||
@ -0,0 +1,16 @@
|
||||
# element access is not permitted for strings
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
string str = "hello world!" ;
|
||||
println str ;
|
||||
str[0] = 'H' ;
|
||||
println str
|
||||
end
|
||||
@ -0,0 +1,23 @@
|
||||
# Two overloads with different pair parameters when called with a
|
||||
# `null` directly leads to ambiguity.
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(pair(int, char) x) is
|
||||
return 1
|
||||
end
|
||||
|
||||
int f(pair(int, bool) x) is
|
||||
return 2
|
||||
end
|
||||
|
||||
int result = call f(null) ;
|
||||
println result
|
||||
end
|
||||
@ -0,0 +1,25 @@
|
||||
# Overloads "fun" in a way that calling with a char array
|
||||
# Could match multiple signatures (char[] or string)
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int fun(string x) is
|
||||
int y = 1;
|
||||
return y
|
||||
end
|
||||
|
||||
int fun(char[] x) is
|
||||
return len x
|
||||
end
|
||||
|
||||
char[] text = ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'] ;
|
||||
int result = call fun(text) ;
|
||||
println result
|
||||
end
|
||||
@ -0,0 +1,27 @@
|
||||
# Two overloads with identical parameter lists but different return types
|
||||
# leads to ambiguity.
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(int x) is
|
||||
return x
|
||||
end
|
||||
|
||||
bool f(int x) is
|
||||
return x > 0
|
||||
end
|
||||
|
||||
bool result = call f(10) ;
|
||||
if result then
|
||||
println "Positive"
|
||||
else
|
||||
println "Non-positive"
|
||||
fi
|
||||
end
|
||||
@ -0,0 +1,24 @@
|
||||
# Defines overloaded "f" for int and bool, but calls with a char.
|
||||
# There's no matching overload that takes a char.
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(int x) is
|
||||
return x + 1
|
||||
end
|
||||
|
||||
bool f(bool x) is
|
||||
return !x
|
||||
end
|
||||
|
||||
# Attempt to call f with a char argument -> no matching overload
|
||||
int test = call f('z') ;
|
||||
println test
|
||||
end
|
||||
@ -0,0 +1,14 @@
|
||||
# Calling an undefined identifier.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int result = call fib(5)
|
||||
end
|
||||
@ -0,0 +1,17 @@
|
||||
# functions cannot define the same argument twice
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int foo(int x, int x) is
|
||||
return x
|
||||
end
|
||||
|
||||
skip
|
||||
end
|
||||
@ -0,0 +1,20 @@
|
||||
# functions cannot access global variables
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f() is
|
||||
x = -1 ;
|
||||
return 0
|
||||
end
|
||||
int x = 5 ;
|
||||
int y = call f() ;
|
||||
println x
|
||||
end
|
||||
|
||||
@ -0,0 +1,17 @@
|
||||
# tries to assign to a function
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f() is
|
||||
return 3
|
||||
end
|
||||
|
||||
f = 2
|
||||
end
|
||||
@ -0,0 +1,18 @@
|
||||
# function parameter misuse
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(int x) is
|
||||
bool b = x && true ;
|
||||
return 0
|
||||
end
|
||||
int x = call f(0)
|
||||
end
|
||||
|
||||
@ -0,0 +1,18 @@
|
||||
# function call type mismatch
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f() is
|
||||
return 0
|
||||
end
|
||||
bool b = call f() ;
|
||||
println b
|
||||
end
|
||||
|
||||
@ -0,0 +1,23 @@
|
||||
# function call type mismatch for both functions
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f() is
|
||||
return 0
|
||||
end
|
||||
|
||||
bool f() is
|
||||
return false
|
||||
end
|
||||
|
||||
string s = call f() ;
|
||||
println s
|
||||
end
|
||||
|
||||
@ -0,0 +1,18 @@
|
||||
# function parameter type mismatch
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(int x) is
|
||||
return x
|
||||
end
|
||||
bool b = true ;
|
||||
int x = call f(b)
|
||||
end
|
||||
|
||||
@ -0,0 +1,24 @@
|
||||
# function parameter type mismatch for either function
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(pair(int, int) x) is
|
||||
int y = fst x ;
|
||||
return y
|
||||
end
|
||||
|
||||
char f(int x) is
|
||||
return chr x
|
||||
end
|
||||
|
||||
bool b = true ;
|
||||
int x = call f(b)
|
||||
end
|
||||
|
||||
@ -0,0 +1,18 @@
|
||||
# function return type mismatch: int <- char
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f() is
|
||||
return 'c'
|
||||
end
|
||||
int x = call f() ;
|
||||
println x
|
||||
end
|
||||
|
||||
@ -0,0 +1,23 @@
|
||||
# function return type mismatch: int <- char even though there's an overloaded function with the correct return
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
char f() is
|
||||
return 'c'
|
||||
end
|
||||
|
||||
int f() is
|
||||
return 'c'
|
||||
end
|
||||
|
||||
int x = call f() ;
|
||||
println x
|
||||
end
|
||||
|
||||
@ -0,0 +1,23 @@
|
||||
# function return type mismatch: char <- int even though there's an overloaded function with the correct return
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
char f() is
|
||||
return 1
|
||||
end
|
||||
|
||||
int f() is
|
||||
return 'a'
|
||||
end
|
||||
|
||||
char x = call f() ;
|
||||
println x
|
||||
end
|
||||
|
||||
@ -0,0 +1,18 @@
|
||||
# function call with too many arguments
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(int a, int b) is
|
||||
return 0
|
||||
end
|
||||
int x = call f(1,2,3);
|
||||
println x
|
||||
end
|
||||
|
||||
@ -0,0 +1,23 @@
|
||||
# function call with more arguments than either function
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(int a) is
|
||||
return 0
|
||||
end
|
||||
|
||||
int f(int a, int b) is
|
||||
return 0
|
||||
end
|
||||
|
||||
int x = call f(1,2,3);
|
||||
println x
|
||||
end
|
||||
|
||||
@ -0,0 +1,20 @@
|
||||
# attempted redefinition of function of the same signature
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f() is
|
||||
return 0
|
||||
end
|
||||
int f() is
|
||||
return 1
|
||||
end
|
||||
int x = call f();
|
||||
println x
|
||||
end
|
||||
@ -0,0 +1,20 @@
|
||||
# attempted redefinition of function of the same signature
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(int a, string b, char c, bool d, pair(int, int) e, int[] f) is
|
||||
return 0
|
||||
end
|
||||
int f(int a, string b, char c, bool d, pair(int, int) e, int[] f) is
|
||||
return 1
|
||||
end
|
||||
|
||||
skip
|
||||
end
|
||||
@ -0,0 +1,19 @@
|
||||
# function call with arguments swapped
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(int a, bool b) is
|
||||
return 0
|
||||
end
|
||||
int x = call f(true,1);
|
||||
println x
|
||||
end
|
||||
|
||||
|
||||
@ -0,0 +1,19 @@
|
||||
# function call with too few arguments
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(int a, int b) is
|
||||
return 0
|
||||
end
|
||||
int x = call f(1);
|
||||
println x
|
||||
end
|
||||
|
||||
|
||||
@ -0,0 +1,22 @@
|
||||
# Trying to return something invalid from a branch.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f() is
|
||||
if true then
|
||||
return horse
|
||||
else
|
||||
return 1
|
||||
fi
|
||||
end
|
||||
|
||||
int result = call f()
|
||||
end
|
||||
@ -0,0 +1,27 @@
|
||||
# Trying to return two different types from a function.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(bool b) is
|
||||
if b then
|
||||
return 'c';
|
||||
while true do
|
||||
return 'd'
|
||||
done
|
||||
else
|
||||
skip
|
||||
fi;
|
||||
|
||||
return 10
|
||||
end
|
||||
|
||||
int a = call f(false)
|
||||
end
|
||||
@ -0,0 +1,17 @@
|
||||
# if condition type mismatch
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
if 15 + 6 then
|
||||
skip
|
||||
else
|
||||
skip
|
||||
fi
|
||||
end
|
||||
@ -0,0 +1,25 @@
|
||||
# a complete mess of function definition and use
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f(bool b) is
|
||||
x = -1 ;
|
||||
int x = 1;
|
||||
bool x = 3;
|
||||
return "done"
|
||||
end
|
||||
|
||||
int x = 5 ;
|
||||
int y = call f(x) ;
|
||||
println y ;
|
||||
println x
|
||||
|
||||
end
|
||||
|
||||
@ -0,0 +1,19 @@
|
||||
# boolean typo and if condition type mismatch
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
while tru do
|
||||
if 15 + 6 then
|
||||
skip
|
||||
else
|
||||
skip
|
||||
fi
|
||||
done
|
||||
end
|
||||
@ -0,0 +1,16 @@
|
||||
# long expression with multiple type mismatches
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
|
||||
int x = ((1 + true) * (2 - false)) || 17 ;
|
||||
exit x
|
||||
|
||||
end
|
||||
@ -0,0 +1,20 @@
|
||||
# variable names are case sensitive in all possible ways
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int number = 42 ;
|
||||
println NUMBER ;
|
||||
|
||||
int INDEX = 0 ;
|
||||
println index ;
|
||||
|
||||
int miXed = 3 ;
|
||||
println MIxED
|
||||
end
|
||||
@ -0,0 +1,15 @@
|
||||
# multiple type mismatches: int <- bool, bool <- char, char <- int
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int i = true ;
|
||||
bool b = 'a' ;
|
||||
char c = 10
|
||||
end
|
||||
@ -0,0 +1,44 @@
|
||||
# Trying to obfuscate invalid returns with whiles.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int f() is
|
||||
if true then
|
||||
while true do
|
||||
return 'a'
|
||||
done
|
||||
else
|
||||
return 5
|
||||
fi;
|
||||
if false then
|
||||
return 2
|
||||
else
|
||||
while false do
|
||||
return 'b'
|
||||
done
|
||||
fi;
|
||||
if true then
|
||||
while true do
|
||||
return -2
|
||||
done
|
||||
else
|
||||
while false do
|
||||
return -4
|
||||
done
|
||||
fi;
|
||||
while false do
|
||||
return !"horse"
|
||||
done;
|
||||
exit -1
|
||||
end
|
||||
|
||||
int i = call f()
|
||||
end
|
||||
@ -0,0 +1,13 @@
|
||||
# newpair cannot be assigned to non-pair type
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int x = newpair(10, 20)
|
||||
end
|
||||
@ -0,0 +1,15 @@
|
||||
# Assignment is not legal when both sides types are not known
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
pair(int, int) p = newpair(4, 5);
|
||||
pair(pair, int) q = newpair(p, 6);
|
||||
fst fst q = snd fst q
|
||||
end
|
||||
@ -0,0 +1,16 @@
|
||||
# call free on a non-pair type
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int x = 5;
|
||||
println x;
|
||||
free 9 ;
|
||||
println x
|
||||
end
|
||||
@ -0,0 +1,13 @@
|
||||
# newpair must match the underlying type of the pair
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
pair(char, bool) x = newpair(10, 20)
|
||||
end
|
||||
@ -0,0 +1,15 @@
|
||||
# Pairs should not be covariant
|
||||
# Thanks to Nathaniel Burke
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
pair(char[], char[]) pcs = null ;
|
||||
pair(string, string) bad = pcs
|
||||
end
|
||||
@ -0,0 +1,16 @@
|
||||
# Trying to assign non-matching pairs to eachother.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
pair(int, int) p1 = newpair(0, 0);
|
||||
pair(char, char) p2 = newpair('a', 'a');
|
||||
p2 = p1
|
||||
end
|
||||
@ -0,0 +1,14 @@
|
||||
# Reading is not legal when the type is not known
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
pair(pair, int) p = null;
|
||||
read fst fst p
|
||||
end
|
||||
@ -0,0 +1,18 @@
|
||||
# Trying to place an incorrect type into a parameterless pair.
|
||||
# Thanks to Ethan Range, Fawwaz Abdullah, Robbie Buxton, and Edward Hartley
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int c = 5;
|
||||
pair(int, int) p = newpair(c, c);
|
||||
pair(pair, int) oops = newpair(p, 0);
|
||||
|
||||
fst oops = c
|
||||
end
|
||||
@ -0,0 +1,15 @@
|
||||
# type mismatch: int <- char
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
int x = 4 ;
|
||||
char y = 'a' ;
|
||||
println x + y
|
||||
end
|
||||
@ -0,0 +1,14 @@
|
||||
# If we are reading into a pair projection, the element must be readable
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
pair(bool, bool) p = null;
|
||||
read fst p
|
||||
end
|
||||
@ -0,0 +1,14 @@
|
||||
# If we are reading into a pair projection, the element must be readable
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
pair(bool, bool) p = null;
|
||||
read snd p
|
||||
end
|
||||
@ -0,0 +1,15 @@
|
||||
# read into pair type
|
||||
|
||||
# Output:
|
||||
# #semantic_error#
|
||||
|
||||
# Exit:
|
||||
# 200
|
||||
|
||||
# Program:
|
||||
|
||||
begin
|
||||
pair(int, int) p = newpair(1,2);
|
||||
read p
|
||||
end
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user