Skip to content

Commit

Permalink
refactor: support alternative drivers besides postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
lukewhrit committed Aug 7, 2023
1 parent aa0fdb2 commit fdc0cfc
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 64 deletions.
15 changes: 8 additions & 7 deletions cmd/spirit/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,18 @@ func init() {
Err(err).
Msg("Could not load config")
}
}

func main() {
pg, err := database.NewPostgres()

// Start server and initialize database
if err := database.Init(); err != nil {
if err != nil {
log.Fatal().
Err(err).
Msg("Could not connect to database")
}
}

func main() {
m := server.NewServer(&config.Config, database.Connection)
m := server.NewServer(&config.Config, pg)

m.MountMiddleware()
m.RegisterHeaders()
Expand Down Expand Up @@ -100,7 +101,7 @@ func main() {
}

// Database
err := database.Close()
err := pg.Close()

if err != nil {
log.Fatal().
Expand All @@ -117,7 +118,7 @@ func main() {
Msg("Starting HTTP listener")

// Start the server
err := srv.ListenAndServe()
err = srv.ListenAndServe()

if err != nil && err != http.ErrServerClosed {
log.Fatal().
Expand Down
43 changes: 12 additions & 31 deletions internal/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,23 @@
package database

import (
"database/sql"
"context"
"time"

_ "github.com/lib/pq"
"github.com/orca-group/spirit/internal/config"
)

// Connection holds the current connection to the database
var Connection *sql.DB

func migrate() error {
_, err := Connection.Exec(`
CREATE TABLE IF NOT EXISTS documents (
id varchar(255) PRIMARY KEY,
content text NOT NULL,
created_at timestamp with time zone DEFAULT now(),
updated_at timestamp with time zone DEFAULT now()
)`)

return err
type Document struct {
ID string `db:"id" json:"id"`
Content string `db:"content" json:"content"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}

// Init opens a connection to the database
func Init() (err error) {
Connection, err = sql.Open("postgres", config.Config.ConnectionURI)

if err != nil {
return err
}

if err := migrate(); err != nil {
return err
}

return nil
}
type Database interface {
Migrate(ctx context.Context) error
Close() error

func Close() error {
return nil
GetDocument(ctx context.Context, id string) (Document, error)
CreateDocument(ctx context.Context, id, content string) error
}
44 changes: 44 additions & 0 deletions internal/database/database_mock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2020-2023 Luke Whritenour
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package database

import (
"context"
"database/sql"

_ "github.com/lib/pq"
)

type Mock struct {
*sql.DB
}

func NewMock() (Database, error) {
return &Mock{&sql.DB{}}, nil
}

func (m *Mock) Migrate(ctx context.Context) error {
return nil
}

func (m *Mock) GetDocument(ctx context.Context, id string) (Document, error) {
return Document{}, nil
}

func (m *Mock) CreateDocument(ctx context.Context, id, content string) error {
return nil
}
40 changes: 29 additions & 11 deletions internal/database/document.go → internal/database/database_pg.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2023 Luke Whritenour, Jack Dorland
* Copyright 2020-2023 Luke Whritenour
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,27 +17,45 @@
package database

import (
"context"
"database/sql"
"time"

_ "github.com/lib/pq"
"github.com/orca-group/spirit/internal/config"
)

type Document struct {
ID string `db:"id" json:"id"`
Content string `db:"content" json:"content"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
type Postgres struct {
*sql.DB
}

func NewPostgres() (Database, error) {
db, err := sql.Open("postgres", config.Config.ConnectionURI)

return &Postgres{db}, err
}

func (p *Postgres) Migrate(ctx context.Context) error {
_, err := p.Exec(`
CREATE TABLE IF NOT EXISTS documents (
id varchar(255) PRIMARY KEY,
content text NOT NULL,
created_at timestamp with time zone DEFAULT now(),
updated_at timestamp with time zone DEFAULT now()
)`)

return err
}

func FindDocument(db *sql.DB, id string) (Document, error) {
func (p *Postgres) GetDocument(ctx context.Context, id string) (Document, error) {
doc := new(Document)
row := db.QueryRow("SELECT * FROM documents WHERE id=$1", id)
row := p.QueryRow("SELECT * FROM documents WHERE id=$1", id)
err := row.Scan(&doc.ID, &doc.Content, &doc.CreatedAt, &doc.UpdatedAt)

return *doc, err
}

func CreateDocument(db *sql.DB, id, content string) error {
tx, err := db.Begin()
func (p *Postgres) CreateDocument(ctx context.Context, id, content string) error {
tx, err := p.Begin()

if err != nil {
return err
Expand Down
6 changes: 4 additions & 2 deletions internal/server/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
package server

import (
"database/sql"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/orca-group/spirit/internal/config"
"github.com/orca-group/spirit/internal/database"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -65,7 +65,9 @@ func checkResponseCode(t *testing.T, expected, actual int) {
}

func TestConfig(t *testing.T) {
s := NewServer(&mockConfig, &sql.DB{})
mock, _ := database.NewMock()

s := NewServer(&mockConfig, mock)
s.MountHandlers()

req, _ := http.NewRequest("GET", "/config", nil)
Expand Down
9 changes: 4 additions & 5 deletions internal/server/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"fmt"
"net/http"

"github.com/orca-group/spirit/internal/database"
"github.com/orca-group/spirit/internal/util"
)

Expand All @@ -42,8 +41,8 @@ func createDocument(s *Server, w http.ResponseWriter, r *http.Request) string {
// Add Document object to database
id := util.GenerateID(s.Config.IDType, s.Config.IDLength)

if err := database.CreateDocument(
s.Database,
if err := s.Database.CreateDocument(
r.Context(),
id,
body.Content,
); err != nil {
Expand All @@ -57,7 +56,7 @@ func createDocument(s *Server, w http.ResponseWriter, r *http.Request) string {
func (s *Server) CreateDocument(w http.ResponseWriter, r *http.Request) {
// Create document, then pull it from the database
id := createDocument(s, w, r)
document, err := database.FindDocument(s.Database, id)
document, err := s.Database.GetDocument(r.Context(), id)

if err != nil {
util.WriteError(w, http.StatusInternalServerError, err)
Expand All @@ -74,7 +73,7 @@ func (s *Server) CreateDocument(w http.ResponseWriter, r *http.Request) {
func (s *Server) StaticCreateDocument(w http.ResponseWriter, r *http.Request) {
// Create document, then pull it from the database
id := createDocument(s, w, r)
document, err := database.FindDocument(s.Database, id)
document, err := s.Database.GetDocument(r.Context(), id)

if err != nil {
util.WriteError(w, http.StatusInternalServerError, err)
Expand Down
11 changes: 6 additions & 5 deletions internal/server/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package server

import (
"context"
"database/sql"
"errors"
"fmt"
Expand All @@ -30,9 +31,9 @@ import (
"golang.org/x/exp/slices"
)

func getDocument(s *Server, w http.ResponseWriter, id string) database.Document {
func getDocument(s *Server, w http.ResponseWriter, ctx context.Context, id string) database.Document {
// Retrieve document from the database
document, err := database.FindDocument(s.Database, id)
document, err := s.Database.GetDocument(ctx, id)

if err != nil {
// If the document is not found (ErrNoRows), return the error with a 404
Expand Down Expand Up @@ -61,7 +62,7 @@ func (s *Server) StaticDocument(w http.ResponseWriter, r *http.Request) {
}

// Retrieve document from the database
document := getDocument(s, w, id)
document := getDocument(s, w, r.Context(), id)

t, err := template.ParseFS(resources, "web/document.html")

Expand Down Expand Up @@ -98,7 +99,7 @@ func (s *Server) FetchDocument(w http.ResponseWriter, r *http.Request) {
return
}

document := getDocument(s, w, id)
document := getDocument(s, w, r.Context(), id)

// Try responding with the document and a 200, or write an error if that fails
if err := util.WriteJSON(w, http.StatusOK, document); err != nil {
Expand All @@ -117,7 +118,7 @@ func (s *Server) FetchRawDocument(w http.ResponseWriter, r *http.Request) {
return
}

document := getDocument(s, w, id)
document := getDocument(s, w, r.Context(), id)

// Respond with only the documents content
w.WriteHeader(http.StatusOK)
Expand Down
6 changes: 3 additions & 3 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package server

import (
"database/sql"
"embed"
"io/fs"
"net/http"
Expand All @@ -28,6 +27,7 @@ import (
"github.com/go-chi/cors"
"github.com/go-chi/httprate"
"github.com/orca-group/spirit/internal/config"
"github.com/orca-group/spirit/internal/database"
"github.com/orca-group/spirit/internal/util"
"github.com/rs/zerolog/log"
)
Expand All @@ -38,10 +38,10 @@ var resources embed.FS
type Server struct {
Router *chi.Mux
Config *config.Cfg
Database *sql.DB
Database database.Database
}

func NewServer(config *config.Cfg, db *sql.DB) *Server {
func NewServer(config *config.Cfg, db database.Database) *Server {
s := &Server{}
s.Router = chi.NewRouter()
s.Config = config
Expand Down

0 comments on commit fdc0cfc

Please sign in to comment.