-
Notifications
You must be signed in to change notification settings - Fork 1
/
svm.go
50 lines (44 loc) · 932 Bytes
/
svm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
package main
import (
"fmt"
"math"
)
// SVM is a linear binary classifier trained by stochastic sub-gradient
// descent on the L2-regularized hinge loss. Class labels are -1 and +1.
type SVM struct {
	weights []float64 // per-feature coefficients, allocated by Fit
	bias    float64   // intercept term, learned by Fit (not regularized)
	c       float64   // L2 regularization (weight-decay) strength
}

// NewSVM returns an untrained SVM with regularization strength c.
//
// c scales the weight-decay term applied on every training step, so larger
// values shrink the weights more strongly; note this is the inverse sense of
// the "C" penalty used by many SVM libraries. c must be finite and
// non-negative, and NewSVM panics otherwise: a negative c can make the
// weights grow without bound, and a NaN or infinite c turns them into NaN.
func NewSVM(c float64) *SVM {
	if c < 0 || math.IsNaN(c) || math.IsInf(c, 0) {
		panic(fmt.Sprintf("svm: regularization strength must be finite and non-negative, got %v", c))
	}
	return &SVM{c: c}
}

// Fit trains the model from scratch, discarding any previous weights and bias.
//
// Each of the epochs passes visits the samples in order and takes a
// sub-gradient step of size alpha on the per-sample objective
//
//	(c/2)*||w||^2 + max(0, 1 - y_i*(w·x_i + b))
//
// X holds one sample per row and y the matching labels, which are expected
// to be -1 or +1 (the values Predict returns).
//
// Fit panics if X is empty, if len(y) != len(X), or if the rows of X do not
// all have the same length.
func (svm *SVM) Fit(X [][]float64, y []float64, epochs int, alpha float64) {
	if len(X) == 0 {
		panic("svm: Fit called with no training samples")
	}
	if len(y) != len(X) {
		panic(fmt.Sprintf("svm: Fit got %d samples but %d labels", len(X), len(y)))
	}
	nFeatures := len(X[0])
	for i, row := range X {
		if len(row) != nFeatures {
			panic(fmt.Sprintf("svm: sample %d has %d features, want %d", i, len(row), nFeatures))
		}
	}

	svm.weights = make([]float64, nFeatures)
	svm.bias = 0.0
	for epoch := 0; epoch < epochs; epoch++ {
		for i, xi := range X {
			// The hinge condition must use the raw margin w·x+b, not the
			// thresholded ±1 label: correctly classified samples that still
			// lie inside the margin (0 < y*f(x) < 1) have to be pushed out.
			if y[i]*svm.decisionValue(xi) < 1 {
				for j, v := range xi {
					svm.weights[j] += alpha * (y[i]*v - svm.c*svm.weights[j])
				}
				svm.bias += alpha * y[i]
			} else {
				// Outside the margin only the regularizer contributes.
				for j := range svm.weights {
					svm.weights[j] += alpha * (-svm.c * svm.weights[j])
				}
			}
		}
	}
}

// Predict returns the predicted class of X: +1 when the decision value
// w·X + b is non-negative and -1 otherwise.
//
// X must have as many features as the model was trained on; Predict panics
// otherwise (which includes calling it with a non-empty X before Fit).
func (svm *SVM) Predict(X []float64) float64 {
	if svm.decisionValue(X) >= 0 {
		return 1.0
	}
	return -1.0
}

// decisionValue returns the raw signed score w·x + b for x. It panics if
// len(x) does not match the number of learned weights, rather than silently
// ignoring features or failing with an index error.
func (svm *SVM) decisionValue(x []float64) float64 {
	if len(x) != len(svm.weights) {
		panic(fmt.Sprintf("svm: got %d features, model expects %d", len(x), len(svm.weights)))
	}
	z := 0.0
	for i, w := range svm.weights {
		z += w * x[i]
	}
	return z + svm.bias
}