Skip to content

Commit

Permalink
feat: add support for authentication plugins (#416)
Browse files Browse the repository at this point in the history
Authentication plugins accept the FTL endpoint URL as their argument and
output one or more HTTP headers to be used for authentication. These
headers will be added to every outbound request.

There's an internal initial working proof of concept authentication
plugin that uses [kooky](https://github.com/browserutils/kooky) to
authenticate using an oauth2-proxy cookie. This seems to work well.

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
alecthomas and github-actions[bot] authored Sep 22, 2023
1 parent dd080ec commit 01b0e76
Show file tree
Hide file tree
Showing 9 changed files with 305 additions and 40 deletions.
187 changes: 187 additions & 0 deletions authn/authn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package authn

import (
"bufio"
"context"
"fmt"
"io"
"net/http"
"net/url"
"os/user"
"strings"
"sync"
"time"

"github.com/alecthomas/errors"
"github.com/zalando/go-keyring"

"github.com/TBD54566975/ftl/backend/common/exec"
"github.com/TBD54566975/ftl/backend/common/log"
)

// GetAuthenticationHeaders returns authentication headers for the given endpoint.
//
// "authenticators" are authenticator executables to use for each endpoint. The key is the URL of the endpoint, the
// value is the name/path of the authenticator executable. The authenticator executable will be called with the URL as
// the first argument, and output a list of headers to stdout to use for authentication.
//
// If the endpoint is already authenticated, the existing credentials will be returned. Additionally, credentials will
// be cached across runs in the keyring.
func GetAuthenticationHeaders(ctx context.Context, endpoint *url.URL, authenticators map[string]string) (http.Header, error) {
logger := log.FromContext(ctx).Scope(endpoint.Hostname())

endpoint = &url.URL{
Scheme: endpoint.Scheme,
Host: endpoint.Host,
User: endpoint.User,
}

usr, err := user.Current()
if err != nil {
return nil, errors.WithStack(err)
}

// First, check if we have credentials in the keyring and that they work.
keyringKey := "ftl+" + endpoint.String()
logger.Debugf("Trying keyring key %s", keyringKey)
creds, err := keyring.Get(keyringKey, usr.Name)
if errors.Is(err, keyring.ErrNotFound) {
logger.Tracef("No credentials found in keyring")
} else if err != nil {
return nil, errors.WithStack(err)
} else {
logger.Tracef("Credentials found in keyring: %s", creds)
if headers, err := checkAuth(ctx, logger, endpoint, creds); err != nil {
return nil, errors.WithStack(err)
} else if headers != nil {
return headers, nil
}
}

// Next, try the authenticator.
logger.Debugf("Trying authenticator")
authenticator, ok := authenticators[endpoint.Hostname()]
if !ok {
logger.Tracef("No authenticator found in %s", authenticators)
return nil, nil
}

cmd := exec.Command(ctx, log.Error, ".", authenticator, endpoint.String())
out := &strings.Builder{}
cmd.Stdout = out
err = cmd.Run()
if err != nil {
return nil, errors.Wrapf(err, "authenticator %s failed", authenticator)
}

creds = out.String()
if headers, err := checkAuth(ctx, logger, endpoint, creds); err != nil {
return nil, errors.WithStack(err)
} else if headers != nil {
logger.Debugf("Authenticator %s succeeded", authenticator)
w := &strings.Builder{}
for name, values := range headers {
for _, value := range values {
fmt.Printf("%s: %s\r\n", name, value)
}
}
err = keyring.Set(keyringKey, usr.Name, w.String())
if err != nil {
logger.Warnf("Failed to save credentials to keyring: %s", err)
}
return headers, nil
}

return nil, nil
}

// Check credentials and return authenticating headers if we're able to successfully authenticate.
func checkAuth(ctx context.Context, logger *log.Logger, endpoint *url.URL, creds string) (http.Header, error) {
// Parse the headers
headers := http.Header{}
buf := bufio.NewScanner(strings.NewReader(creds))
logger.Tracef("Parsing credentials")
for buf.Scan() {
line := buf.Text()
name, value, ok := strings.Cut(line, ":")
if !ok {
return nil, errors.Errorf("invalid header %q", line)
}
headers[name] = append(headers[name], strings.TrimSpace(value))
}
if buf.Err() != nil {
return nil, errors.WithStack(buf.Err())
}

// Issue a HEAD request with the headers to verify we get a 200 back.
client := &http.Client{
Timeout: time.Second * 5,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodHead, endpoint.String(), nil)
if err != nil {
return nil, errors.WithStack(err)
}
logger.Debugf("Authentication probe: %s %s", req.Method, req.URL)
for header, values := range headers {
for _, value := range values {
req.Header.Add(header, value)
}
}
logger.Tracef("Authenticating with headers %s", headers)
resp, err := client.Do(req)
if err != nil {
return nil, errors.WithStack(err)
}
defer resp.Body.Close() //nolint:gosec
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Debugf("Endpoint returned %d for authenticated request", resp.StatusCode)
logger.Debugf("Response headers: %s", resp.Header)
logger.Debugf("Response body: %s", body)
return nil, nil
}
logger.Debugf("Successfully authenticated with %s", headers)
return headers, nil
}

// Transport returns a transport that will authenticate requests to the given endpoints.
func Transport(next http.RoundTripper, authenticators map[string]string) http.RoundTripper {
return &authnTransport{
authenticators: authenticators,
credentials: map[string]http.Header{},
next: next,
}
}

type authnTransport struct {
lock sync.RWMutex
authenticators map[string]string
credentials map[string]http.Header
next http.RoundTripper
}

func (a *authnTransport) RoundTrip(r *http.Request) (*http.Response, error) {
a.lock.RLock()
creds, ok := a.credentials[r.URL.Hostname()]
a.lock.RUnlock()
if !ok {
var err error
creds, err = GetAuthenticationHeaders(r.Context(), r.URL, a.authenticators)
if err != nil {
return nil, errors.Wrapf(err, "failed to get authentication headers for %s", r.URL.Hostname())
}
a.lock.Lock()
a.credentials[r.URL.Hostname()] = creds
a.lock.Unlock()
}
for header, values := range creds {
for _, value := range values {
r.Header.Add(header, value)
}
}
resp, err := a.next.RoundTrip(r)
return resp, errors.WithStack(err)
}
66 changes: 37 additions & 29 deletions backend/common/rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,59 +13,67 @@ import (
"github.com/jpillora/backoff"
"golang.org/x/net/http2"

"github.com/TBD54566975/ftl/authn"
"github.com/TBD54566975/ftl/backend/common/log"
ftlv1 "github.com/TBD54566975/ftl/protos/xyz/block/ftl/v1"
)

var (
dialer = &net.Dialer{
Timeout: time.Second * 10,
}
h2cClient = func() *http.Client {
var netTransport = &http2.Transport{
// InitialiseClients HTTP clients used by the RPC system.
//
// "authenticators" are authenticator executables to use for each endpoint. The key is the URL of the endpoint, the
// value is the path to the authenticator executable.
func InitialiseClients(authenticators map[string]string) {
// We can't have a client-wide timeout because it also applies to
// streaming RPCs, timing them out.
h2cClient = &http.Client{
Transport: authn.Transport(&http2.Transport{
AllowHTTP: true,
DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
conn, err := dialer.Dial(network, addr)
return conn, errors.WithStack(err)
},
}
return &http.Client{
// We can't have a client-wide timeout because it also applies to
// streaming RPCs, timing them out.
// Timeout: time.Second * 10,
Transport: netTransport,
}
}()
tlsClient = func() *http.Client {
netTransport := &http2.Transport{
}, authenticators),
}
tlsClient = &http.Client{
Transport: authn.Transport(&http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) {
tlsDialer := tls.Dialer{Config: config, NetDialer: dialer}
conn, err := tlsDialer.DialContext(ctx, network, addr)
return conn, errors.WithStack(err)
},
}
return &http.Client{
// We can't have a client-wide timeout because it also applies to
// streaming RPCs, timing them out.
// Timeout: time.Second * 10,
Transport: netTransport,
}
}()
}, authenticators),
}
}

var (
dialer = &net.Dialer{
Timeout: time.Second * 10,
}
h2cClient *http.Client
tlsClient *http.Client
)

type Pingable interface {
Ping(context.Context, *connect.Request[ftlv1.PingRequest]) (*connect.Response[ftlv1.PingResponse], error)
}

// GetHTTPClient returns a HTTP client usable for the given URL.
func GetHTTPClient(url string) *http.Client {
if h2cClient == nil {
panic("rpc.InitialiseClients() must be called before GetHTTPClient()")
}
if strings.HasPrefix(url, "http://") {
return h2cClient
}
return tlsClient
}

// ClientFactory is a function that creates a new client and is typically one of
// the New*Client functions generated by protoc-gen-connect-go.
type ClientFactory[Client Pingable] func(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) Client

func Dial[Client Pingable](factory ClientFactory[Client], baseURL string, errorLevel log.Level, opts ...connect.ClientOption) Client {
client := tlsClient
if strings.HasPrefix(baseURL, "http://") {
client = h2cClient
}
client := GetHTTPClient(baseURL)
opts = append(opts, DefaultClientOptions(errorLevel)...)
return factory(client, baseURL, opts...)
}
Expand Down
17 changes: 12 additions & 5 deletions cmd/ftl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
"syscall"

"github.com/alecthomas/kong"
kongtoml "github.com/alecthomas/kong-toml"
"github.com/bufbuild/connect-go"

_ "github.com/TBD54566975/ftl/backend/common/automaxprocs" // Set GOMAXPROCS to match Linux container CPU quota.
log2 "github.com/TBD54566975/ftl/backend/common/log"
"github.com/TBD54566975/ftl/backend/common/log"
"github.com/TBD54566975/ftl/backend/common/rpc"
"github.com/TBD54566975/ftl/protos/xyz/block/ftl/v1/ftlv1connect"
)
Expand All @@ -21,9 +22,12 @@ var version = "dev"

type CLI struct {
Version kong.VersionFlag `help:"Show version."`
LogConfig log2.Config `embed:"" prefix:"log-" group:"Logging:"`
Config kong.ConfigFlag `help:"Load configuration from TOML file." placeholder:"FILE"`
LogConfig log.Config `embed:"" prefix:"log-" group:"Logging:"`
Endpoint *url.URL `default:"http://127.0.0.1:8892" help:"FTL endpoint to bind/connect to." env:"FTL_ENDPOINT"`

Authenticators map[string]string `help:"Authenticators to use for FTL endpoints." mapsep:"," env:"FTL_AUTHENTICATORS" placeholder:"HOST=EXE,…"`

Status statusCmd `cmd:"" help:"Show FTL status."`
PS psCmd `cmd:"" help:"List deployments."`
Serve serveCmd `cmd:"" help:"Start the FTL server."`
Expand All @@ -40,6 +44,7 @@ var cli CLI
func main() {
kctx := kong.Parse(&cli,
kong.Description(`FTL - Towards a 𝝺-calculus for large-scale systems`),
kong.Configuration(kongtoml.Loader, ".ftl.toml", "~/.ftl.toml"),
kong.AutoGroup(func(parent kong.Visitable, flag *kong.Flag) *kong.Group {
node, ok := parent.(*kong.Command)
if !ok {
Expand All @@ -54,13 +59,15 @@ func main() {
},
)

rpc.InitialiseClients(cli.Authenticators)

// Set the log level for child processes.
os.Setenv("LOG_LEVEL", cli.LogConfig.Level.String())

ctx, cancel := context.WithCancel(context.Background())

logger := log2.Configure(os.Stderr, cli.LogConfig)
ctx = log2.ContextWithLogger(ctx, logger)
logger := log.Configure(os.Stderr, cli.LogConfig)
ctx = log.ContextWithLogger(ctx, logger)

// Handle signals.
sigch := make(chan os.Signal, 1)
Expand All @@ -86,6 +93,6 @@ func main() {

func makeDialer[Client rpc.Pingable](newClient func(connect.HTTPClient, string, ...connect.ClientOption) Client) func() (Client, error) {
return func() (Client, error) {
return rpc.Dial(newClient, cli.Endpoint.String(), log2.Error), nil
return rpc.Dial(newClient, cli.Endpoint.String(), log.Error), nil
}
}
5 changes: 5 additions & 0 deletions examples/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,24 @@ require (
github.com/alecthomas/errors v0.4.0 // indirect
github.com/alecthomas/participle/v2 v2.0.0 // indirect
github.com/alecthomas/types v0.7.1 // indirect
github.com/alessio/shellescape v1.4.1 // indirect
github.com/bufbuild/connect-go v1.8.0 // indirect
github.com/bufbuild/connect-grpcreflect-go v1.1.0 // indirect
github.com/bufbuild/connect-opentelemetry-go v0.3.0 // indirect
github.com/danieljoos/wincred v1.1.0 // indirect
github.com/go-logr/logr v1.2.4 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/iancoleman/strcase v0.2.0 // indirect
github.com/jpillora/backoff v1.0.0 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/oklog/ulid/v2 v2.1.0 // indirect
github.com/rs/cors v1.9.0 // indirect
github.com/swaggest/jsonschema-go v0.3.59 // indirect
github.com/swaggest/refl v1.2.0 // indirect
github.com/zalando/go-keyring v0.2.1 // indirect
go.opentelemetry.io/otel v1.16.0 // indirect
go.opentelemetry.io/otel/metric v1.16.0 // indirect
go.opentelemetry.io/otel/trace v1.16.0 // indirect
Expand Down
Loading

0 comments on commit 01b0e76

Please sign in to comment.