-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Student.scala
131 lines (102 loc) · 2.7 KB
/
Student.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package lecture
import java.util.Random
/**
 * A state of a small Markov reward process: every state carries an immediate
 * reward and a fixed table of successor probabilities.
 *
 * See tests [[StudentClassSpec]]
 */
sealed trait State {

  /** Immediate reward received in this state. */
  def reward: Int

  /** Draws a successor state according to this state's transition probabilities. */
  def nextRandomState(): State

  /**
   * Successor states paired with the probability of moving to each of them.
   * An empty list means no successor is listed (only `State.Sleep` does this).
   */
  def possibleNextStates(): List[(State, Double)]
}

object State {

  /** Stays on Facebook with probability 0.9, otherwise goes to Class1. */
  case object Facebook extends State {
    /** Per-state generator used by `nextRandomState`; call `setSeed` for reproducible draws. */
    val random: Random = new Random()
    val reward: Int = -1

    def nextRandomState(): State =
      if (random.nextDouble() <= 0.9) this else State.Class1

    def possibleNextStates(): List[(State, Double)] =
      List((this, 0.9), (State.Class1, 0.1))
  }

  /** Goes to Facebook or Class2 with equal probability. */
  case object Class1 extends State {
    /** Per-state generator used by `nextRandomState`; call `setSeed` for reproducible draws. */
    val random: Random = new Random()
    val reward: Int = -2

    def nextRandomState(): State =
      if (random.nextDouble() <= 0.5) State.Facebook else State.Class2

    def possibleNextStates(): List[(State, Double)] =
      List((State.Facebook, 0.5), (State.Class2, 0.5))
  }

  /** Goes to Class3 with probability 0.8, otherwise to Sleep. */
  case object Class2 extends State {
    /** Per-state generator used by `nextRandomState`; call `setSeed` for reproducible draws. */
    val random: Random = new Random()
    val reward: Int = -2

    def nextRandomState(): State =
      if (random.nextDouble() <= 0.8) State.Class3 else State.Sleep

    def possibleNextStates(): List[(State, Double)] =
      List((State.Class3, 0.8), (State.Sleep, 0.2))
  }

  /** Goes to Pass with probability 0.6, otherwise to Pub. */
  case object Class3 extends State {
    /** Per-state generator used by `nextRandomState`; call `setSeed` for reproducible draws. */
    val random: Random = new Random()
    val reward: Int = -2

    def nextRandomState(): State =
      if (random.nextDouble() <= 0.6) State.Pass else State.Pub

    def possibleNextStates(): List[(State, Double)] =
      List((State.Pass, 0.6), (State.Pub, 0.4))
  }

  /** Goes back to Class3 (0.4), Class2 (0.4) or Class1 (0.2). */
  case object Pub extends State {
    /** Per-state generator used by `nextRandomState`; call `setSeed` for reproducible draws. */
    val random: Random = new Random()
    val reward: Int = -1

    def nextRandomState(): State = {
      val draw = random.nextDouble()
      // Cumulative thresholds are checked Class3 -> Class2 -> Class1, i.e. the
      // reverse of the order `possibleNextStates` lists; keep this order so a
      // seeded generator keeps mapping draws to the same successors.
      if (draw <= 0.4) State.Class3
      else if (draw <= 0.8) State.Class2
      else State.Class1
    }

    def possibleNextStates(): List[(State, Double)] =
      List((State.Class1, 0.2), (State.Class2, 0.4), (State.Class3, 0.4))
  }

  /** Always moves to Sleep; no random number is drawn. */
  case object Pass extends State {
    val reward: Int = 10

    def nextRandomState(): State = State.Sleep

    def possibleNextStates(): List[(State, Double)] =
      List((State.Sleep, 1.0))
  }

  /**
   * Absorbing state: `nextRandomState` always returns Sleep, while
   * `possibleNextStates` is empty, so enumerating successors finds none.
   */
  case object Sleep extends State {
    val reward: Int = 0

    def nextRandomState(): State = State.Sleep

    def possibleNextStates(): List[(State, Double)] =
      List.empty
  }
}
object Agent {

  /**
   * Totals the rewards of the states in `sample`, visited in iteration order,
   * weighting the reward at 0-based step `k` by `discountFactor^k`.
   *
   * @param sample         visited states; the first one is step 0
   * @param discountFactor per-step discount; `None` weights every reward by 1
   * @return the (optionally discounted) sum of rewards; 0.0 for an empty sample
   */
  def calculateValue(sample: Iterable[State], discountFactor: Option[Double]): Double = {
    // Math.pow(x, 0) is 1.0 for every x, so the first reward is never discounted.
    def weightAt(step: Int): Double =
      discountFactor.fold(1.0)(gamma => Math.pow(gamma, step))

    sample.iterator.zipWithIndex.foldLeft(0.0) { (acc, visit) =>
      val (state, step) = visit
      acc + state.reward * weightAt(step)
    }
  }
}