Skip to content

Commit

Permalink
Abstract DB so we can switch between MongoDB and SQLite.
Browse files Browse the repository at this point in the history
SQLite is only used for local development.
  • Loading branch information
robholland committed Nov 20, 2024
1 parent 7ffc6c9 commit d54f0f0
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 103 deletions.
194 changes: 194 additions & 0 deletions app/db/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package db

import (
"context"
_ "embed"
"fmt"
"time"

"github.com/jmoiron/sqlx"
"github.com/temporalio/reference-app-orders-go/app/config"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
mongodb "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)

// OrdersCollection is the name of the MongoDB collection to use for Orders.
const OrdersCollection = "orders"

// ShipmentCollection is the name of the MongoDB collection to use for Shipment data.
const ShipmentCollection = "shipments"

// DB is an interface that defines the methods that a database driver must implement
type DB interface {
Connect(ctx context.Context) error
Setup() error
Close() error
InsertOrder(context.Context, interface{}) error
UpdateOrderStatus(context.Context, string, string) error
GetOrders(context.Context, interface{}) error
UpdateShipmentStatus(context.Context, string, string) error
GetShipments(context.Context, interface{}) error
}

// CreateDB creates a new DB instance based on the configuration
func CreateDB(config config.AppConfig) DB {
if config.MongoURL != "" {
return &MongoDB{uri: config.MongoURL}
}

return &SQLiteDB{path: "./api-data.db"}
}

// MongoDB is a struct that implements the DB interface for MongoDB
type MongoDB struct {
uri string
client *mongo.Client
db *mongo.Database
}

// Connect connects to a MongoDB instance
func (m *MongoDB) Connect(ctx context.Context) error {
client, err := mongo.Connect(ctx, options.Client().ApplyURI(m.uri))
if err != nil {
return err
}
m.client = client
m.db = client.Database("orders")
return nil
}

// Setup sets up the MongoDB instance
func (m *MongoDB) Setup() error {
orders := m.db.Collection(OrdersCollection)
_, err := orders.Indexes().CreateOne(context.TODO(), mongodb.IndexModel{
Keys: map[string]interface{}{"received_at": 1},
})
if err != nil {
return fmt.Errorf("failed to create orders index: %w", err)
}

shipments := m.db.Collection(ShipmentCollection)
_, err = shipments.Indexes().CreateOne(context.TODO(), mongodb.IndexModel{
Keys: map[string]interface{}{"booked_at": 1},
})
if err != nil {
return fmt.Errorf("failed to create shipment index: %w", err)
}

return nil
}

// InsertOrder inserts an Order into the MongoDB instance
func (m *MongoDB) InsertOrder(ctx context.Context, order interface{}) error {
_, err := m.db.Collection(OrdersCollection).InsertOne(ctx, order)
return err
}

// UpdateOrderStatus updates an Order in the MongoDB instance
func (m *MongoDB) UpdateOrderStatus(ctx context.Context, id string, status string) error {
_, err := m.db.Collection(OrdersCollection).UpdateOne(ctx, bson.M{"id": id}, bson.M{"$set": bson.M{"status": status}})
return err
}

// GetOrders returns a list of Orders from the MongoDB instance
func (m *MongoDB) GetOrders(ctx context.Context, result interface{}) error {
res, err := m.db.Collection(OrdersCollection).Find(ctx, bson.M{}, &options.FindOptions{
Sort: bson.M{"received_at": 1},
})
if err != nil {
return err
}

return res.All(ctx, result)
}

// UpdateShipmentStatus updates a Shipment in the MongoDB instance
func (m *MongoDB) UpdateShipmentStatus(ctx context.Context, id string, status string) error {
_, err := m.db.Collection(ShipmentCollection).UpdateOne(
ctx,
bson.M{"id": id},
bson.M{
"$set": bson.M{"status": status},
"$setOnInsert": bson.M{"booked_at": time.Now().UTC()},
},
)
return err
}

// GetShipments returns a list of Shipments from the MongoDB instance
func (m *MongoDB) GetShipments(ctx context.Context, result interface{}) error {
res, err := m.db.Collection(ShipmentCollection).Find(ctx, bson.M{}, &options.FindOptions{
Sort: bson.M{"booked_at": 1},
})
if err != nil {
return err
}

return res.All(ctx, result)
}

// Close closes the connection to the MongoDB instance
func (m *MongoDB) Close() error {
return m.client.Disconnect(context.Background())
}

// SQLiteDB is a struct that implements the DB interface for SQLite
type SQLiteDB struct {
path string
db *sqlx.DB
}

//go:embed schema.sql
var sqliteSchema string

// Connect connects to a SQLite instance
func (s *SQLiteDB) Connect(_ context.Context) error {
db, err := sqlx.Connect("sqlite3", s.path)
if err != nil {
return err
}
s.db = db
db.SetMaxOpenConns(1) // SQLite does not support concurrent writes
return nil
}

// Setup sets up the SQLite instance
func (s *SQLiteDB) Setup() error {
_, err := s.db.Exec(sqliteSchema)
return err
}

// Close closes the connection to the SQLite instance
func (s *SQLiteDB) Close() error {
return s.db.Close()
}

// InsertOrder inserts an Order into the SQLite instance
func (s *SQLiteDB) InsertOrder(ctx context.Context, order interface{}) error {
_, err := s.db.NamedExecContext(ctx, "INSERT INTO orders (id, received_at, status) VALUES (:id, :received_at, :status)", order)
return err
}

// UpdateOrderStatus updates an Order in the SQLite instance
func (s *SQLiteDB) UpdateOrderStatus(ctx context.Context, id string, status string) error {
_, err := s.db.ExecContext(ctx, "UPDATE orders SET status = ? WHERE id = ?", status, id)
return err
}

// GetOrders returns a list of Orders from the SQLite instance
func (s *SQLiteDB) GetOrders(ctx context.Context, result interface{}) error {
return s.db.SelectContext(ctx, result, "SELECT * FROM orders ORDER BY received_at")
}

// UpdateShipmentStatus updates a Shipment in the SQLite instance
func (s *SQLiteDB) UpdateShipmentStatus(ctx context.Context, id string, status string) error {
_, err := s.db.ExecContext(ctx, "INSERT INTO shipments (id, booked_at, status) VALUES (?, ?, ?) ON CONFLICT(id) DO UPDATE SET status = ?", id, time.Now().UTC(), status, status)
return err
}

// GetShipments returns a list of Shipments from the SQLite instance
func (s *SQLiteDB) GetShipments(ctx context.Context, result interface{}) error {
return s.db.SelectContext(ctx, result, "SELECT * FROM shipments ORDER BY booked_at")
}
16 changes: 16 additions & 0 deletions app/db/schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
CREATE TABLE IF NOT EXISTS orders (
id TEXT PRIMARY KEY,
customer_id TEXT NOT NULL,
received_at TIMESTAMP NOT NULL,
status TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS orders_received_at ON orders(received_at DESC);

CREATE TABLE IF NOT EXISTS shipments (
id TEXT PRIMARY KEY,
status TEXT NOT NULL,
booked_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

CREATE INDEX IF NOT EXISTS shipments_booked_at ON shipments (booked_at DESC);
31 changes: 8 additions & 23 deletions app/order/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import (
"strings"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/temporalio/reference-app-orders-go/app/db"
"go.temporal.io/api/enums/v1"
"go.temporal.io/api/serviceerror"
"go.temporal.io/sdk/client"
Expand All @@ -23,9 +21,6 @@ const TaskQueue = "orders"
// StatusQuery is the name of the query to use to fetch an Order's status.
const StatusQuery = "status"

// OrdersCollection is the name of the MongoDB collection to use for Orders.
const OrdersCollection = "orders"

// OrderWorkflowID returns the workflow ID for an Order.
func OrderWorkflowID(id string) string {
return "Order:" + id
Expand Down Expand Up @@ -201,16 +196,15 @@ type OrderResult struct {

type handlers struct {
temporal client.Client
orders *mongo.Collection
db db.DB
logger *slog.Logger
}

// Router implements the http.Handler interface for the Billing API
func Router(client client.Client, db *mongo.Database, logger *slog.Logger) http.Handler {
func Router(client client.Client, db db.DB, logger *slog.Logger) http.Handler {
r := http.NewServeMux()

orders := db.Collection(OrdersCollection)
h := handlers{temporal: client, orders: orders, logger: logger}
h := handlers{temporal: client, db: db, logger: logger}

r.HandleFunc("POST /orders", h.handleCreateOrder)
r.HandleFunc("GET /orders", h.handleListOrders)
Expand All @@ -223,18 +217,9 @@ func Router(client client.Client, db *mongo.Database, logger *slog.Logger) http.

func (h *handlers) handleListOrders(w http.ResponseWriter, _ *http.Request) {
ctx := context.TODO()
orders := []ListOrderEntry{}

res, err := h.orders.Find(ctx, bson.M{}, &options.FindOptions{
Sort: bson.M{"received_at": 1},
})
if err != nil {
h.logger.Error("Failed to list orders", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

err = res.All(ctx, &orders)
var orders []ListOrderEntry
err := h.db.GetOrders(ctx, &orders)
if err != nil {
h.logger.Error("Failed to list orders", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -282,7 +267,7 @@ func (h *handlers) handleCreateOrder(w http.ResponseWriter, r *http.Request) {
Status: OrderStatusPending,
}

_, err = h.orders.InsertOne(context.Background(), status)
err = h.db.InsertOrder(context.Background(), status)
if err != nil {
h.logger.Error("Failed to record workflow status", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -334,7 +319,7 @@ func (h *handlers) handleUpdateOrderStatus(w http.ResponseWriter, r *http.Reques
return
}

_, err = h.orders.UpdateOne(context.Background(), bson.M{"id": status.ID}, bson.M{"$set": bson.M{"status": status.Status}})
err = h.db.UpdateOrderStatus(context.Background(), status.ID, status.Status)
if err != nil {
h.logger.Error("Failed to update order status", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down
34 changes: 5 additions & 29 deletions app/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package server
import (
"context"
"crypto/tls"
_ "embed"
"fmt"
"log/slog"
"net/http"
Expand All @@ -13,11 +12,10 @@ import (

"github.com/temporalio/reference-app-orders-go/app/billing"
"github.com/temporalio/reference-app-orders-go/app/config"
"github.com/temporalio/reference-app-orders-go/app/db"
"github.com/temporalio/reference-app-orders-go/app/fraud"
"github.com/temporalio/reference-app-orders-go/app/order"
"github.com/temporalio/reference-app-orders-go/app/shipment"
mongodb "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.temporal.io/sdk/client"
"go.temporal.io/sdk/log"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -68,27 +66,6 @@ func CreateClientOptionsFromEnv() (client.Options, error) {
return clientOpts, nil
}

// SetupDB creates indexes in the database.
func SetupDB(db *mongodb.Database) error {
orders := db.Collection(order.OrdersCollection)
_, err := orders.Indexes().CreateOne(context.TODO(), mongodb.IndexModel{
Keys: map[string]interface{}{"received_at": 1},
})
if err != nil {
return fmt.Errorf("failed to create database index: %w", err)
}

shipments := db.Collection(shipment.ShipmentCollection)
_, err = shipments.Indexes().CreateOne(context.TODO(), mongodb.IndexModel{
Keys: map[string]interface{}{"booked_at": 1},
})
if err != nil {
return fmt.Errorf("failed to create database index: %w", err)
}

return nil
}

// RunWorkers runs workers for the requested services.
func RunWorkers(ctx context.Context, config config.AppConfig, client client.Client, services []string) error {
ctx, cancel := context.WithCancel(ctx)
Expand Down Expand Up @@ -184,16 +161,15 @@ func RunAPIServers(ctx context.Context, config config.AppConfig, client client.C

g, ctx := errgroup.WithContext(ctx)

var db *mongodb.Database
db := db.CreateDB(config)

if slices.Contains(services, "orders") || slices.Contains(services, "shipment") {
c, err := mongodb.Connect(context.TODO(), options.Client().ApplyURI(config.MongoURL))
db = c.Database("orders")
err := db.Connect(context.TODO())
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
return fmt.Errorf("failed to connect to database: %w", err)
}

if err := SetupDB(db); err != nil {
if err := db.Setup(); err != nil {
return err
}
}
Expand Down
Loading

0 comments on commit d54f0f0

Please sign in to comment.