Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow config override by envvar #763

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions cmd/nvidia-container-runtime-hook/hook_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ func getDefaultHookConfig() (HookConfig, error) {
}

// loadConfig loads the required paths for the hook config.
func loadConfig() (*config.Config, error) {
func (a *app) loadConfig() (*config.Config, error) {
var configPaths []string
var required bool
if len(*configflag) != 0 {
configPaths = append(configPaths, *configflag)
if len(a.configFile) != 0 {
configPaths = append(configPaths, a.configFile)
required = true
} else {
configPaths = append(configPaths, path.Join(driverPath, configPath), configPath)
Expand All @@ -56,8 +56,8 @@ func loadConfig() (*config.Config, error) {
return config.GetDefault()
}

func getHookConfig() (*HookConfig, error) {
cfg, err := loadConfig()
func (a *app) getHookConfig() (*HookConfig, error) {
cfg, err := a.loadConfig()
if err != nil {
return nil, fmt.Errorf("failed to load config: %v", err)
}
Expand Down
7 changes: 4 additions & 3 deletions cmd/nvidia-container-runtime-hook/hook_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,17 @@ func TestGetHookConfig(t *testing.T) {
if len(filename) > 0 {
os.Remove(filename)
}
configflag = nil
}()

a := &app{}

if tc.lines != nil {
configFile, err := os.CreateTemp("", "*.toml")
require.NoError(t, err)
defer configFile.Close()

filename = configFile.Name()
configflag = &filename
a.configFile = filename

for _, line := range tc.lines {
_, err := configFile.WriteString(fmt.Sprintf("%s\n", line))
Expand All @@ -91,7 +92,7 @@ func TestGetHookConfig(t *testing.T) {

var config HookConfig
getHookConfig := func() {
c, _ := getHookConfig()
c, _ := a.getHookConfig()
config = *c
}

Expand Down
123 changes: 68 additions & 55 deletions cmd/nvidia-container-runtime-hook/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package main

import (
"flag"
"errors"
"fmt"
"log"
"os"
Expand All @@ -13,29 +13,26 @@ import (
"strings"
"syscall"

cli "github.com/urfave/cli/v2"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
)

var (
debugflag = flag.Bool("debug", false, "enable debug output")
versionflag = flag.Bool("version", false, "enable version output")
configflag = flag.String("config", "", "configuration file")
)

// recoverIfRequired calls recover. When recover yields a panic value it logs
// the value if it is a runtime.Error, logs the stack trace when a.isDebug is
// set, and returns the runtime.Error. It returns nil when recover yields nil.
//
// NOTE(review): this span interleaves the removed exit() lines with the added
// recoverIfRequired() lines.
// NOTE(review): per the Go spec, recover returns nil unless it is called
// directly by a deferred function. doPrestart invokes this method from inside
// a deferred closure, so the recover() below is one call too deep and would
// not stop a panic — confirm, then either defer this method directly or call
// recover() in the closure and pass the value in.
func exit() {
func (a *app) recoverIfRequired() error {
if err := recover(); err != nil {
if _, ok := err.(runtime.Error); ok {
rerr, ok := err.(runtime.Error)
if ok {
log.Println(err)
}
if *debugflag {
if a.isDebug {
log.Printf("%s", debug.Stack())
}
os.Exit(1)
// NOTE(review): rerr is nil when the panic value is not a runtime.Error, so
// such a panic is reported to the caller as success — confirm this is intended.
return rerr
}
os.Exit(0)
return nil
}

func getCLIPath(config config.ContainerCLIConfig) string {
Expand Down Expand Up @@ -63,27 +60,27 @@ func getRootfsPath(config containerConfig) string {
return rootfs
}

func doPrestart() {
var err error

defer exit()
func (a *app) doPrestart() (rerr error) {
defer func() {
rerr = errors.Join(rerr, a.recoverIfRequired())
}()
log.SetFlags(0)

hook, err := getHookConfig()
hook, err := a.getHookConfig()
if err != nil || hook == nil {
log.Panicln("error getting hook config:", err)
return fmt.Errorf("error getting hook config: %w", err)
}
cli := hook.NVIDIAContainerCLIConfig

container := getContainerConfig(*hook)
nvidia := container.Nvidia
if nvidia == nil {
// Not a GPU container, nothing to do.
return
return nil
}

if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" {
log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.")
return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead")
}

rootfs := getRootfsPath(container)
Expand All @@ -101,7 +98,7 @@ func doPrestart() {
if cli.NoPivot {
args = append(args, "--no-pivot")
}
if *debugflag {
if a.isDebug {
args = append(args, "--debug=/dev/stderr")
} else if cli.Debug != "" {
args = append(args, fmt.Sprintf("--debug=%s", cli.Debug))
Expand Down Expand Up @@ -149,45 +146,61 @@ func doPrestart() {

env := append(os.Environ(), cli.Environment...)
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection?
err = syscall.Exec(args[0], args, env)
log.Panicln("exec failed:", err)
return syscall.Exec(args[0], args, env)
}

func usage() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nCommands:\n")
fmt.Fprintf(os.Stderr, " prestart\n run the prestart hook\n")
fmt.Fprintf(os.Stderr, " poststart\n no-op\n")
fmt.Fprintf(os.Stderr, " poststop\n no-op\n")
type (
	// options holds the settings that the hook's command line flags write into.
	options struct {
		// isDebug enables debug output.
		isDebug bool
		// configFile is the path to the configuration file to use; empty means
		// that no explicit path was given.
		configFile string
	}

	// app is the hook application; it embeds options so that the flag values
	// are reachable directly as a.isDebug and a.configFile.
	app struct {
		options
	}
)

// main is the entry point of the hook. In the added lines it builds a
// urfave/cli application whose global "debug" and "config" flags are bound to
// the fields of an app value, registers the "prestart" command (which calls
// a.doPrestart) and the no-op "poststart" command (alias "poststop"), runs the
// application against os.Args and exits with status 1 when Run returns an error.
//
// NOTE(review): this span interleaves the removed flag-package implementation
// with the added urfave/cli implementation.
func main() {
flag.Usage = usage
flag.Parse()

if *versionflag {
fmt.Printf("%v version %v\n", "NVIDIA Container Runtime Hook", info.GetVersionString())
return
}

args := flag.Args()
if len(args) == 0 {
flag.Usage()
os.Exit(2)
}

switch args[0] {
case "prestart":
doPrestart()
os.Exit(0)
case "poststart":
fallthrough
case "poststop":
os.Exit(0)
default:
flag.Usage()
os.Exit(2)
a := &app{}
// Create the top-level CLI
c := cli.NewApp()
c.Name = "NVIDIA Container Runtime Hook"
c.Version = info.GetVersionString()

// Global flags; each one writes straight into the corresponding field of a.
c.Flags = []cli.Flag{
&cli.BoolFlag{
Name: "debug",
Destination: &a.isDebug,
Usage: "Enabled debug output",
},
// The config path can also be supplied through the environment variable
// named by config.FilePathOverrideEnvVar.
&cli.StringFlag{
Name: "config",
Destination: &a.configFile,
Usage: "The path to the configuration file to use",
EnvVars: []string{config.FilePathOverrideEnvVar},
},
}

c.Commands = []*cli.Command{
{
Name: "prestart",
Usage: "run the prestart hook",
Action: func(ctx *cli.Context) error {
return a.doPrestart()
},
},
{
Name: "poststart",
Aliases: []string{"poststop"},
Usage: "no-op",
Action: func(ctx *cli.Context) error {
return nil
},
},
}
// NOTE(review): with a default command set, invoking the hook without a
// command presumably runs prestart, whereas the removed code printed the
// usage and exited with status 2 — confirm this change is intended.
c.DefaultCommand = "prestart"

// Run the CLI
err := c.Run(os.Args)
if err != nil {
// NOTE(review): err is discarded here; confirm the failure reason is
// reported elsewhere, otherwise log it before exiting.
os.Exit(1)
}
}

Expand Down
16 changes: 11 additions & 5 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ import (
)

const (
configOverride = "XDG_CONFIG_HOME"
configFilePath = "nvidia-container-runtime/config.toml"
FilePathOverrideEnvVar = "NVCTK_CONFIG_FILE_PATH"
RelativeFilePath = "nvidia-container-runtime/config.toml"

configRootOverride = "XDG_CONFIG_HOME"

nvidiaCTKExecutable = "nvidia-ctk"
nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk"
Expand Down Expand Up @@ -71,11 +73,15 @@ type Config struct {

// GetConfigFilePath returns the path to the config file for the configured system.
//
// In the added lines the lookup order is:
//  1. the value of the FilePathOverrideEnvVar environment variable, returned
//     as-is when it is non-empty;
//  2. RelativeFilePath joined onto the directory named by the
//     configRootOverride environment variable, when that is non-empty;
//  3. RelativeFilePath joined onto "/etc".
//
// NOTE(review): this span interleaves the removed and the added lines.
func GetConfigFilePath() string {
if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 {
return filepath.Join(XDGConfigDir, configFilePath)
// An explicit file path override takes precedence over any root directory.
if configFilePathOverride := os.Getenv(FilePathOverrideEnvVar); configFilePathOverride != "" {
return configFilePathOverride
}
configRoot := "/etc"
if XDGConfigDir := os.Getenv(configRootOverride); len(XDGConfigDir) != 0 {
configRoot = XDGConfigDir
}

return filepath.Join("/etc", configFilePath)
return filepath.Join(configRoot, RelativeFilePath)
}

// GetConfig sets up the config struct. Values are read from a toml file
Expand Down
21 changes: 19 additions & 2 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,26 @@ import (

// TestGetConfigWithCustomConfig points the config-root environment variable at
// a temporary directory, writes a config file at the relative config path
// below it, and checks that GetConfig returns the debug file path set in that
// file.
//
// NOTE(review): this span interleaves the removed and the added lines.
func TestGetConfigWithCustomConfig(t *testing.T) {
testDir := t.TempDir()
t.Setenv(configOverride, testDir)
t.Setenv(configRootOverride, testDir)

filename := filepath.Join(testDir, configFilePath)
filename := filepath.Join(testDir, RelativeFilePath)

// By default debug is disabled, so a debug path in the loaded config shows
// that the file written below was actually read.
contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"")

// NOTE(review): 0766 leaves group/other without the execute (search) bit on
// the directory; 0755 is the more usual mode — confirm.
require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766))
require.NoError(t, os.WriteFile(filename, contents, 0600))

cfg, err := GetConfig()
require.NoError(t, err)
require.Equal(t, "/nvidia-container-toolkit.log", cfg.NVIDIAContainerRuntimeConfig.DebugFilePath)
}

func TestGetConfigWithConfigFilePathOverride(t *testing.T) {
testDir := t.TempDir()
filename := filepath.Join(testDir, RelativeFilePath)

t.Setenv(FilePathOverrideEnvVar, filename)

// By default debug is disabled
contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"")
Expand Down
5 changes: 0 additions & 5 deletions tools/container/toolkit/executable.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ type executable struct {
target executableTarget
env map[string]string
preLines []string
argLines []string
}

// install installs an executable component of the NVIDIA container toolkit. The source executable
Expand Down Expand Up @@ -128,10 +127,6 @@ func (e executable) writeWrapperTo(wrapper io.Writer, destFolder string, dotfile
// Add the call to the target executable
fmt.Fprintf(wrapper, "%s \\\n", dotfileName)

// Insert additional lines in the `arg` list
for _, line := range e.argLines {
fmt.Fprintf(wrapper, "\t%s \\\n", r.apply(line))
}
// Add the script arguments "$@"
fmt.Fprintln(wrapper, "\t\"$@\"")

Expand Down
17 changes: 0 additions & 17 deletions tools/container/toolkit/executable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,6 @@ func TestWrapper(t *testing.T) {
"",
},
},
{
e: executable{
argLines: []string{
"argline1",
"argline2",
},
},
expectedLines: []string{
shebang,
"PATH=/dest/folder:$PATH \\",
"source.real \\",
"\targline1 \\",
"\targline2 \\",
"\t\"$@\"",
"",
},
},
}

for i, tc := range testCases {
Expand Down
13 changes: 7 additions & 6 deletions tools/container/toolkit/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"path/filepath"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/tools/container/operator"
)

Expand All @@ -29,10 +30,10 @@ const (

// installContainerRuntimes sets up the NVIDIA container runtimes, copying the executables
// and implementing the required wrapper
func installContainerRuntimes(toolkitDir string, driverRoot string) error {
func installContainerRuntimes(toolkitDir string, configFilePath string) error {
runtimes := operator.GetRuntimes()
for _, runtime := range runtimes {
r := newNvidiaContainerRuntimeInstaller(runtime.Path)
r := newNvidiaContainerRuntimeInstaller(runtime.Path, configFilePath)

_, err := r.install(toolkitDir)
if err != nil {
Expand All @@ -46,17 +47,17 @@ func installContainerRuntimes(toolkitDir string, driverRoot string) error {
// This installer will copy the specified source executable to the toolkit directory.
// The executable is copied to a file with the same name as the source, but with a ".real" suffix and a wrapper is
// created to allow for the configuration of the runtime environment.
// In the added signature, configFilePath is passed through unchanged to
// newRuntimeInstaller; no additional environment entries are supplied (nil).
// NOTE(review): this span interleaves the removed and the added lines.
func newNvidiaContainerRuntimeInstaller(source string) *executable {
func newNvidiaContainerRuntimeInstaller(source string, configFilePath string) *executable {
// The wrapper takes the base name of the source executable; the dotfile name
// is that same name with a ".real" suffix.
wrapperName := filepath.Base(source)
dotfileName := wrapperName + ".real"
target := executableTarget{
dotfileName: dotfileName,
wrapperName: wrapperName,
}
return newRuntimeInstaller(source, target, nil)
return newRuntimeInstaller(source, target, configFilePath, nil)
}

func newRuntimeInstaller(source string, target executableTarget, env map[string]string) *executable {
func newRuntimeInstaller(source string, target executableTarget, configFilePath string, env map[string]string) *executable {
preLines := []string{
"",
"cat /proc/modules | grep -e \"^nvidia \" >/dev/null 2>&1",
Expand All @@ -68,7 +69,7 @@ func newRuntimeInstaller(source string, target executableTarget, env map[string]
}

runtimeEnv := make(map[string]string)
runtimeEnv["XDG_CONFIG_HOME"] = filepath.Join(destDirPattern, ".config")
runtimeEnv[config.FilePathOverrideEnvVar] = configFilePath
for k, v := range env {
runtimeEnv[k] = v
}
Expand Down
Loading