diff --git a/cmd/spirit/main.go b/cmd/spirit/main.go index 1bebcdfb..6a6600ff 100644 --- a/cmd/spirit/main.go +++ b/cmd/spirit/main.go @@ -44,17 +44,18 @@ func init() { Err(err). Msg("Could not load config") } +} + +func main() { + pg, err := database.NewPostgres() - // Start server and initialize database - if err := database.Init(); err != nil { + if err != nil { log.Fatal(). Err(err). Msg("Could not connect to database") } -} -func main() { - m := server.NewServer(&config.Config, database.Connection) + m := server.NewServer(&config.Config, pg) m.MountMiddleware() m.RegisterHeaders() @@ -100,7 +101,7 @@ func main() { } // Database - err := database.Close() + err := pg.Close() if err != nil { log.Fatal(). @@ -117,7 +118,7 @@ func main() { Msg("Starting HTTP listener") // Start the server - err := srv.ListenAndServe() + err = srv.ListenAndServe() if err != nil && err != http.ErrServerClosed { log.Fatal(). diff --git a/internal/database/database.go b/internal/database/database.go index 63fea804..6ef45502 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -17,42 +17,23 @@ package database import ( - "database/sql" + "context" + "time" _ "github.com/lib/pq" - "github.com/orca-group/spirit/internal/config" ) -// Connection holds the current connection to the database -var Connection *sql.DB - -func migrate() error { - _, err := Connection.Exec(` -CREATE TABLE IF NOT EXISTS documents ( - id varchar(255) PRIMARY KEY, - content text NOT NULL, - created_at timestamp with time zone DEFAULT now(), - updated_at timestamp with time zone DEFAULT now() -)`) - - return err +type Document struct { + ID string `db:"id" json:"id"` + Content string `db:"content" json:"content"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } -// Init opens a connection to the database -func Init() (err error) { - Connection, err = sql.Open("postgres", config.Config.ConnectionURI) - - if err != nil { - return err - } - - if err := migrate(); err != nil { - return err - } - - return nil -} +type Database interface { + Migrate(ctx context.Context) error + Close() error -func Close() error { - return nil + GetDocument(ctx context.Context, id string) (Document, error) + CreateDocument(ctx context.Context, id, content string) error } diff --git a/internal/database/database_mock.go b/internal/database/database_mock.go new file mode 100644 index 00000000..b623a92c --- /dev/null +++ b/internal/database/database_mock.go @@ -0,0 +1,44 @@ +/* + * Copyright 2020-2023 Luke Whritenour + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package database + +import ( + "context" + "database/sql" + + _ "github.com/lib/pq" +) + +type Mock struct { + *sql.DB +} + +func NewMock() (Database, error) { + return &Mock{&sql.DB{}}, nil +} + +func (m *Mock) Migrate(ctx context.Context) error { + return nil +} + +func (m *Mock) GetDocument(ctx context.Context, id string) (Document, error) { + return Document{}, nil +} + +func (m *Mock) CreateDocument(ctx context.Context, id, content string) error { + return nil +} diff --git a/internal/database/document.go b/internal/database/database_pg.go similarity index 53% rename from internal/database/document.go rename to internal/database/database_pg.go index a366e440..0d018ce7 100644 --- a/internal/database/document.go +++ b/internal/database/database_pg.go @@ -1,5 +1,5 @@ /* - * Copyright 2020-2023 Luke Whritenour, Jack Dorland + * Copyright 2020-2023 Luke Whritenour * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,27 +17,45 @@ package database import ( + "context" "database/sql" - "time" + + _ "github.com/lib/pq" + "github.com/orca-group/spirit/internal/config" ) -type Document struct { - ID string `db:"id" json:"id"` - Content string `db:"content" json:"content"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +type Postgres struct { + *sql.DB +} + +func NewPostgres() (Database, error) { + db, err := sql.Open("postgres", config.Config.ConnectionURI) + + return &Postgres{db}, err +} + +func (p *Postgres) Migrate(ctx context.Context) error { + _, err := p.Exec(` +CREATE TABLE IF NOT EXISTS documents ( + id varchar(255) PRIMARY KEY, + content text NOT NULL, + created_at timestamp with time zone DEFAULT now(), + updated_at timestamp with time zone DEFAULT now() +)`) + + return err } -func FindDocument(db *sql.DB, id string) (Document, error) { +func (p *Postgres) GetDocument(ctx context.Context, id string) (Document, error) { doc := new(Document) - row := db.QueryRow("SELECT * FROM documents WHERE id=$1", id) + row := p.QueryRow("SELECT * FROM documents WHERE id=$1", id) err := row.Scan(&doc.ID, &doc.Content, &doc.CreatedAt, &doc.UpdatedAt) return *doc, err } -func CreateDocument(db *sql.DB, id, content string) error { - tx, err := db.Begin() +func (p *Postgres) CreateDocument(ctx context.Context, id, content string) error { + tx, err := p.Begin() if err != nil { return err diff --git a/internal/server/config_test.go b/internal/server/config_test.go index c23b3a3f..79bebe2b 100644 --- a/internal/server/config_test.go +++ b/internal/server/config_test.go @@ -17,7 +17,6 @@ package server import ( - "database/sql" "encoding/json" "io" "net/http" @@ -25,6 +24,7 @@ import ( "testing" "github.com/orca-group/spirit/internal/config" + "github.com/orca-group/spirit/internal/database" "github.com/stretchr/testify/require" ) @@ -65,7 +65,9 @@ func checkResponseCode(t *testing.T, expected, actual int) { } func TestConfig(t *testing.T) { - s := NewServer(&mockConfig, &sql.DB{}) + mock, _ := database.NewMock() + + s := NewServer(&mockConfig, mock) s.MountHandlers() req, _ := http.NewRequest("GET", "/config", nil) diff --git a/internal/server/create.go b/internal/server/create.go index e1b934b9..c0d55ba5 100644 --- a/internal/server/create.go +++ b/internal/server/create.go @@ -20,7 +20,6 @@ import ( "fmt" "net/http" - "github.com/orca-group/spirit/internal/database" "github.com/orca-group/spirit/internal/util" ) @@ -42,8 +41,8 @@ func createDocument(s *Server, w http.ResponseWriter, r *http.Request) string { // Add Document object to database id := util.GenerateID(s.Config.IDType, s.Config.IDLength) - if err := database.CreateDocument( - s.Database, + if err := s.Database.CreateDocument( + r.Context(), id, body.Content, ); err != nil { @@ -57,7 +56,7 @@ func createDocument(s *Server, w http.ResponseWriter, r *http.Request) string { func (s *Server) CreateDocument(w http.ResponseWriter, r *http.Request) { // Create document, then pull it from the database id := createDocument(s, w, r) - document, err := database.FindDocument(s.Database, id) + document, err := s.Database.GetDocument(r.Context(), id) if err != nil { util.WriteError(w, http.StatusInternalServerError, err) @@ -74,7 +73,7 @@ func (s *Server) CreateDocument(w http.ResponseWriter, r *http.Request) { func (s *Server) StaticCreateDocument(w http.ResponseWriter, r *http.Request) { // Create document, then pull it from the database id := createDocument(s, w, r) - document, err := database.FindDocument(s.Database, id) + document, err := s.Database.GetDocument(r.Context(), id) if err != nil { util.WriteError(w, http.StatusInternalServerError, err) diff --git a/internal/server/fetch.go b/internal/server/fetch.go index d5793ff8..e9044b50 100644 --- a/internal/server/fetch.go +++ b/internal/server/fetch.go @@ -17,6 +17,7 @@ package server import ( + "context" "database/sql" "errors" "fmt" @@ -30,9 +31,9 @@ import ( "golang.org/x/exp/slices" ) -func getDocument(s *Server, w http.ResponseWriter, id string) database.Document { +func getDocument(s *Server, w http.ResponseWriter, ctx context.Context, id string) database.Document { // Retrieve document from the database - document, err := database.FindDocument(s.Database, id) + document, err := s.Database.GetDocument(ctx, id) if err != nil { // If the document is not found (ErrNoRows), return the error with a 404 @@ -61,7 +62,7 @@ func (s *Server) StaticDocument(w http.ResponseWriter, r *http.Request) { } // Retrieve document from the database - document := getDocument(s, w, id) + document := getDocument(s, w, r.Context(), id) t, err := template.ParseFS(resources, "web/document.html") @@ -98,7 +99,7 @@ func (s *Server) FetchDocument(w http.ResponseWriter, r *http.Request) { return } - document := getDocument(s, w, id) + document := getDocument(s, w, r.Context(), id) // Try responding with the document and a 200, or write an error if that fails if err := util.WriteJSON(w, http.StatusOK, document); err != nil { @@ -117,7 +118,7 @@ func (s *Server) FetchRawDocument(w http.ResponseWriter, r *http.Request) { return } - document := getDocument(s, w, id) + document := getDocument(s, w, r.Context(), id) // Respond with only the documents content w.WriteHeader(http.StatusOK) diff --git a/internal/server/server.go b/internal/server/server.go index e948a37f..5633b461 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,7 +17,6 @@ package server import ( - "database/sql" "embed" "io/fs" "net/http" @@ -28,6 +27,7 @@ import ( "github.com/go-chi/cors" "github.com/go-chi/httprate" "github.com/orca-group/spirit/internal/config" + "github.com/orca-group/spirit/internal/database" "github.com/orca-group/spirit/internal/util" "github.com/rs/zerolog/log" ) @@ -38,10 +38,10 @@ var resources embed.FS type Server struct { Router *chi.Mux Config *config.Cfg - Database *sql.DB + Database database.Database } -func NewServer(config *config.Cfg, db *sql.DB) *Server { +func NewServer(config *config.Cfg, db database.Database) *Server { s := &Server{} s.Router = chi.NewRouter() s.Config = config