Skip to content

Commit

Permalink
working gridworld
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric authored and Eric committed Dec 20, 2018
1 parent 1b54221 commit f331f84
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions src/main/scala/discrete/MDP.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ object MDP {
// Per-cell reward, indexed as rewardGrid(row)(col); read when evaluating an action's target cell.
var rewardGrid = Array[Array[Double]]() //fixed
// Multiplier applied to a target cell's current value when computing an action's expected value.
var discountFactor = 0.8 //fixed
// NOTE(review): not read by any code visible in this hunk (the action probabilities used
// elsewhere appear as literals) -- confirm where this is consumed.
var transitionProbabilities = Array[Array[Double]]() //fixed, assumed fixed throughout
// (row, col) coordinates of terminal cells; the full-grid sweep sets their value to 0.0
// instead of iterating on them.
var terminalStatesIndex = Array[Tuple2[Int, Int]]()

def getValueGrid(): Array[Array[Double]] = {
this.valueGrid
Expand All @@ -41,6 +42,10 @@ object MDP {
this.discountFactor = dFactor
}

/**
 * Replaces the stored terminal-state coordinates.
 *
 * @param tStates (row, col) pairs to store as the terminal cells
 */
def setTerminalStatesIndex(tStates: Array[Tuple2[Int, Int]]): Unit = {
  terminalStatesIndex = tStates
}

def getAdjIndices(row: Int, col: Int): Array[Tuple2[Int, Int]] = {

if(rewardGrid.length < 1) return(Array[Tuple2[Int, Int]]())
Expand All @@ -67,46 +72,54 @@ object MDP {

if (dir == "N"){
if (adjInd contains (row - 1, col)) dirInd = dirInd ++ Array((0.8, (row - 1, col))) else dirInd = dirInd ++ Array((0.8, (row, col)))
if (adjInd contains (row, col + 1)) dirInd = dirInd ++ Array((0.2, (row, col + 1))) else dirInd = dirInd ++ Array((0.2, (row, col)))
if (adjInd contains (row, col - 1)) dirInd = dirInd ++ Array((0.2, (row, col - 1))) else dirInd = dirInd ++ Array((0.2, (row, col)))
if (adjInd contains (row, col + 1)) dirInd = dirInd ++ Array((0.1, (row, col + 1))) else dirInd = dirInd ++ Array((0.1, (row, col)))
if (adjInd contains (row, col - 1)) dirInd = dirInd ++ Array((0.1, (row, col - 1))) else dirInd = dirInd ++ Array((0.1, (row, col)))
} else if (dir == "S"){
if (adjInd contains (row + 1, col)) dirInd = dirInd ++ Array((0.8, (row + 1, col))) else dirInd = dirInd ++ Array((0.8, (row, col)))
if (adjInd contains (row, col + 1)) dirInd = dirInd ++ Array((0.2, (row, col + 1))) else dirInd = dirInd ++ Array((0.2, (row, col)))
if (adjInd contains (row, col - 1)) dirInd = dirInd ++ Array((0.2, (row, col - 1))) else dirInd = dirInd ++ Array((0.2, (row, col)))
if (adjInd contains (row, col + 1)) dirInd = dirInd ++ Array((0.1, (row, col + 1))) else dirInd = dirInd ++ Array((0.1, (row, col)))
if (adjInd contains (row, col - 1)) dirInd = dirInd ++ Array((0.1, (row, col - 1))) else dirInd = dirInd ++ Array((0.1, (row, col)))
} else if (dir == "E") {
if (adjInd contains (row, col + 1)) dirInd = dirInd ++ Array((0.8, (row, col + 1))) else dirInd = dirInd ++ Array((0.8, (row, col)))
if (adjInd contains (row + 1, col)) dirInd = dirInd ++ Array((0.2, (row + 1, col))) else dirInd = dirInd ++ Array((0.2, (row, col)))
if (adjInd contains (row - 1, col)) dirInd = dirInd ++ Array((0.2, (row - 1, col))) else dirInd = dirInd ++ Array((0.2, (row, col)))
if (adjInd contains (row + 1, col)) dirInd = dirInd ++ Array((0.1, (row + 1, col))) else dirInd = dirInd ++ Array((0.1, (row, col)))
if (adjInd contains (row - 1, col)) dirInd = dirInd ++ Array((0.1, (row - 1, col))) else dirInd = dirInd ++ Array((0.1, (row, col)))
} else if (dir == "W") {
if (adjInd contains (row, col - 1)) dirInd = dirInd ++ Array((0.8, (row, col - 1))) else dirInd = dirInd ++ Array((0.8, (row, col)))
if (adjInd contains (row + 1, col)) dirInd = dirInd ++ Array((0.2, (row + 1, col))) else dirInd = dirInd ++ Array((0.2, (row, col)))
if (adjInd contains (row - 1, col)) dirInd = dirInd ++ Array((0.2, (row - 1, col))) else dirInd = dirInd ++ Array((0.2, (row, col)))
if (adjInd contains (row + 1, col)) dirInd = dirInd ++ Array((0.1, (row + 1, col))) else dirInd = dirInd ++ Array((0.1, (row, col)))
if (adjInd contains (row - 1, col)) dirInd = dirInd ++ Array((0.1, (row - 1, col))) else dirInd = dirInd ++ Array((0.1, (row, col)))
}
dirInd
}

/**
 * Computes the expected value of taking action `dir` from cell (row, col).
 *
 * For every (probability, target cell) pair returned by getActionIndices, adds
 * probability * (rewardGrid(target) + discountFactor * valueGrid(target)).
 * The grid itself is not modified here; the caller decides what to store.
 *
 * @param dir action label forwarded to getActionIndices (callers pass "N", "S", "E" or "W")
 * @param row row index of the cell the action is taken from
 * @param col column index of the cell the action is taken from
 * @return the probability-weighted sum over all outcomes of the action
 */
def valueIterateAtIndexWithAction(dir: String, row: Int, col: Int): Double = {
  // Left fold from 0.0 keeps the same summation order as an explicit accumulator loop.
  getActionIndices(dir, row, col).foldLeft(0.0) { case (acc, (p, (r, c))) =>
    acc + p * (rewardGrid(r)(c) + discountFactor * valueGrid(r)(c))
  }
}

/**
 * Updates valueGrid(row)(col) with the best expected value over the four actions.
 *
 * Each action ("N", "S", "E", "W") is evaluated exactly once through
 * valueIterateAtIndexWithAction, and the largest result is written to the cell.
 *
 * @param row row index of the cell to update
 * @param col column index of the cell to update
 */
def valueIterateAtIndex(row: Int, col: Int): Unit = {
  val actions = Array("N", "S", "E", "W")
  // All action values are computed before the cell is overwritten, so every
  // evaluation sees the same pre-update grid.
  val actionValues = actions.map(a => valueIterateAtIndexWithAction(a, row, col))
  // Strict '>' keeps the earliest action's value on ties (and when comparing against NaN),
  // matching a running-maximum loop seeded with the "N" value.
  valueGrid(row)(col) = actionValues.reduceLeft((best, v) => if (v > best) v else best)
}

def valueIterateFullGrid(n: Int): Unit = {
for(i <- 0 to n){
for(r <- 0 to valueGrid.length - 1 ){
for(c <- 0 to valueGrid(r).length - 1 ){
valueIterateAtIndex(r, c)
if( this.terminalStatesIndex contains (r,c) ){
this.valueGrid(r)(c) = 0.0
} else {
valueIterateAtIndex(r, c)
}

}
}
print(valueGrid.deep)
Expand All @@ -117,4 +130,5 @@ object MDP {
// Corner point condition
// Use convention left right up down
// Grid world initialization
}
}

0 comments on commit f331f84

Please sign in to comment.