Skip to content

Commit

Permalink
software: Fix race conditions in migration logger
Browse files Browse the repository at this point in the history
  • Loading branch information
ish-hcc committed Nov 18, 2024
1 parent e4d0151 commit 1c2cb22
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 53 deletions.
29 changes: 12 additions & 17 deletions lib/software/configCopier.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func sudoWrapper(cmd string, password string) string {
return fmt.Sprintf("echo '%s' | sudo -S sh -c '%s'", password, strings.Replace(cmd, "'", "'\"'\"'", -1))
}

func findCertKeyPaths(client *ssh.Client, filePath string) ([]string, error) {
func findCertKeyPaths(client *ssh.Client, filePath string, migrationLogger *Logger) ([]string, error) {
migrationLogger.Printf(INFO, "Finding certificate and key paths in file: %s\n", filePath)

cmd := fmt.Sprintf(`
Expand Down Expand Up @@ -104,7 +104,7 @@ func findCertKeyPaths(client *ssh.Client, filePath string) ([]string, error) {
return validPaths, nil
}

func copyFile(sourceClient *ssh.Client, targetClient *ssh.Client, filePath string) error {
func copyFile(sourceClient *ssh.Client, targetClient *ssh.Client, filePath string, migrationLogger *Logger) error {
migrationLogger.Printf(INFO, "Starting file copy process for: %s\n", filePath)

sourcePassword := sourceClient.SSHTarget.Password
Expand Down Expand Up @@ -269,7 +269,7 @@ EOL
return nil
}

func parseConfigLine(line string) ConfigFile {
func parseConfigLine(line string, migrationLogger *Logger) ConfigFile {
migrationLogger.Printf(DEBUG, "Parsing config line: %s\n", line)

parts := strings.SplitN(line, " [", 2)
Expand All @@ -285,7 +285,7 @@ func parseConfigLine(line string) ConfigFile {
return conf
}

func findConfigs(client *ssh.Client, packageName string) ([]ConfigFile, error) {
func findConfigs(client *ssh.Client, packageName string, migrationLogger *Logger) ([]ConfigFile, error) {
migrationLogger.Printf(INFO, "Starting config search for package: %s\n", packageName)

cmd := `#!/bin/sh
Expand Down Expand Up @@ -441,7 +441,7 @@ rm -f "$tmp_result"`
migrationLogger.Printf(DEBUG, "Processing found config files\n")
for _, line := range strings.Split(string(output), "\n") {
if line = strings.TrimSpace(line); line != "" {
conf := parseConfigLine(line)
conf := parseConfigLine(line, migrationLogger)
if !seen[conf.Path] {
seen[conf.Path] = true
configs = append(configs, conf)
Expand All @@ -454,7 +454,7 @@ rm -f "$tmp_result"`
return configs, nil
}

func copyConfigFiles(sourceClient *ssh.Client, targetClient *ssh.Client, configs []ConfigFile) error {
func copyConfigFiles(sourceClient *ssh.Client, targetClient *ssh.Client, configs []ConfigFile, migrationLogger *Logger) error {
migrationLogger.Printf(INFO, "Starting config files copy process for %d files\n", len(configs))

for i, conf := range configs {
Expand All @@ -466,14 +466,14 @@ func copyConfigFiles(sourceClient *ssh.Client, targetClient *ssh.Client, configs
}

migrationLogger.Printf(INFO, "Copying config file: %s [Status: %s]\n", conf.Path, conf.Status)
err := copyFile(sourceClient, targetClient, conf.Path)
err := copyFile(sourceClient, targetClient, conf.Path, migrationLogger)
if err != nil {
migrationLogger.Printf(ERROR, "Failed to copy config file %s: %v\n", conf.Path, err)
return fmt.Errorf("failed to copy file %s: %v", conf.Path, err)
}

migrationLogger.Printf(DEBUG, "Searching for associated cert/key files for: %s\n", conf.Path)
certKeyPaths, err := findCertKeyPaths(sourceClient, conf.Path)
certKeyPaths, err := findCertKeyPaths(sourceClient, conf.Path, migrationLogger)
if err != nil {
migrationLogger.Printf(WARN, "Error finding cert/key paths for %s: %v\n", conf.Path, err)
continue
Expand All @@ -483,7 +483,7 @@ func copyConfigFiles(sourceClient *ssh.Client, targetClient *ssh.Client, configs
migrationLogger.Printf(INFO, "Found %d cert/key files for %s\n", len(certKeyPaths), conf.Path)
for j, path := range certKeyPaths {
migrationLogger.Printf(INFO, "Copying cert/key file %d/%d: %s\n", j+1, len(certKeyPaths), path)
err := copyFile(sourceClient, targetClient, path)
err := copyFile(sourceClient, targetClient, path, migrationLogger)
if err != nil {
migrationLogger.Printf(WARN, "Failed to copy cert/key file %s: %v\n", path, err)
continue
Expand All @@ -501,12 +501,7 @@ func copyConfigFiles(sourceClient *ssh.Client, targetClient *ssh.Client, configs
return nil
}

func configCopier(sourceClient *ssh.Client, targetClient *ssh.Client, packageName, uuid string) error {
if err := initLoggerWithUUID(uuid); err != nil {
return fmt.Errorf("failed to initialize logger: %v", err)
}
defer migrationLogger.Close()

func configCopier(sourceClient *ssh.Client, targetClient *ssh.Client, packageName, uuid string, migrationLogger *Logger) error {
migrationLogger.Printf(INFO, "Starting config copier for package: %s (UUID: %s)\n", packageName, uuid)

if sourceClient == nil || targetClient == nil {
Expand All @@ -520,7 +515,7 @@ func configCopier(sourceClient *ssh.Client, targetClient *ssh.Client, packageNam
}

migrationLogger.Printf(INFO, "Finding configuration files for package: %s\n", packageName)
configs, err := findConfigs(sourceClient, packageName)
configs, err := findConfigs(sourceClient, packageName, migrationLogger)
if err != nil {
migrationLogger.Printf(ERROR, "Failed to find configs for package %s: %v\n", packageName, err)
return fmt.Errorf("failed to find configs: %v", err)
Expand All @@ -533,7 +528,7 @@ func configCopier(sourceClient *ssh.Client, targetClient *ssh.Client, packageNam

migrationLogger.Printf(INFO, "Found %d configuration files for package %s\n", len(configs), packageName)

if err := copyConfigFiles(sourceClient, targetClient, configs); err != nil {
if err := copyConfigFiles(sourceClient, targetClient, configs, migrationLogger); err != nil {
migrationLogger.Printf(ERROR, "Failed to copy config files for package %s: %v\n", packageName, err)
return fmt.Errorf("failed to copy config files: %v", err)
}
Expand Down
18 changes: 15 additions & 3 deletions lib/software/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,21 @@ func MigrateSoftware(executionID string, executionList *[]model.MigrationSoftwar
return
}

err = configCopier(s, t, execution.SoftwareName, executionID)
migrationLogger, err := initLoggerWithUUID(executionID)
if err != nil {
errMsg := fmt.Sprintf("failed to initialize logger: %v", err)
logger.Println(logger.ERROR, true, "migrateSoftware: ExecutionID="+executionID+
", InstallType=package, SoftwareID="+execution.SoftwareID+", Error="+errMsg)
updateStatus(i, "failed", errMsg, false)
}

err = configCopier(s, t, execution.SoftwareName, executionID, migrationLogger)
if err != nil {
logger.Println(logger.ERROR, true, "migrateSoftware: ExecutionID="+executionID+
", InstallType=package, SoftwareID="+execution.SoftwareID+", Error="+err.Error())
updateStatus(i, "failed", err.Error(), false)

migrationLogger.Close()
return
}

Expand All @@ -105,6 +114,7 @@ func MigrateSoftware(executionID string, executionList *[]model.MigrationSoftwar
", SoftwareID="+execution.SoftwareID+", Error="+err.Error())
updateStatus(i, "failed", err.Error(), false)

migrationLogger.Close()
return
}

Expand All @@ -118,22 +128,24 @@ func MigrateSoftware(executionID string, executionList *[]model.MigrationSoftwar
Status: "Custom",
})
}
err = copyConfigFiles(s, t, customConfigs)
err = copyConfigFiles(s, t, customConfigs, migrationLogger)
if err != nil {
logger.Println(logger.ERROR, true, "migrateSoftware: ExecutionID="+executionID+
", SoftwareID="+execution.SoftwareID+", Error="+err.Error())
updateStatus(i, "failed", err.Error(), false)

migrationLogger.Close()
return
}
}

err = serviceMigrator(s, t, execution.SoftwareName, executionID)
err = serviceMigrator(s, t, execution.SoftwareName, executionID, migrationLogger)
if err != nil {
logger.Println(logger.ERROR, true, "migrateSoftware: ExecutionID="+executionID+
", InstallType=package, SoftwareID="+execution.SoftwareID+", Error="+err.Error())
updateStatus(i, "failed", err.Error(), false)

migrationLogger.Close()
return
}

Expand Down
11 changes: 3 additions & 8 deletions lib/software/listenPortsValidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func findPIDByInode(client *ssh.Client, password, inode string, pids []string) (
return 0, fmt.Errorf("failed to find PID for inode: %s", inode)
}

func compareServicePorts(sourceClient, targetClient *ssh.Client, serviceName string) error {
func compareServicePorts(sourceClient, targetClient *ssh.Client, serviceName string, migrationLogger *Logger) error {
migrationLogger.Printf(INFO, "Starting service migration for package: %s\n", serviceName)

migrationLogger.Printf(DEBUG, "Detecting source PID\n")
Expand Down Expand Up @@ -366,13 +366,8 @@ func readProgramName(client *ssh.Client, password string, pid int) string {
return strings.ReplaceAll(strings.TrimSpace(out.String()), "\x00", " ")
}

func listenPortsValidator(sourceClient *ssh.Client, targetClient *ssh.Client, serviceName, uuid string) error {
if err := initLoggerWithUUID(uuid); err != nil {
return fmt.Errorf("failed to initialize logger: %v", err)
}
defer migrationLogger.Close()

err := compareServicePorts(sourceClient, targetClient, serviceName)
func listenPortsValidator(sourceClient *ssh.Client, targetClient *ssh.Client, serviceName string, migrationLogger *Logger) error {
err := compareServicePorts(sourceClient, targetClient, serviceName, migrationLogger)
if err != nil {
migrationLogger.Printf(ERROR, "Error during comparison: %v\n", err)
return err
Expand Down
14 changes: 6 additions & 8 deletions lib/software/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ type Logger struct {
fpLog *os.File
}

var migrationLogger *Logger

func initLoggerWithUUID(uuid string) error {
migrationLogger = &Logger{}
func initLoggerWithUUID(uuid string) (*Logger, error) {
var migrationLogger = &Logger{}

logPath := filepath.Join(
config.CMGrasshopperConfig.CMGrasshopper.Software.LogFolder,
Expand Down Expand Up @@ -170,25 +168,25 @@ func (l *Logger) Panicf(logLevel string, format string, a ...any) {
}

// Init : Initialize log file
func (l *Logger) Init(logPath, logFileName string) error {
func (l *Logger) Init(logPath, logFileName string) (*Logger, error) {
var err error

if _, err = os.Stat(logPath); os.IsNotExist(err) {
err = fileutil.CreateDirIfNotExist(logPath)
if err != nil {
return err
return nil, err
}
}

l.fpLog, err = os.OpenFile(filepath.Join(logPath, logFileName), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0666)
if err != nil {
l.logger = log.New(io.Writer(os.Stdout), "", log.Ldate|log.Ltime)
return err
return nil, err
}

l.logger = log.New(io.Writer(l.fpLog), "", log.Ldate|log.Ltime)

return nil
return l, nil
}

// Close : Close log file
Expand Down
29 changes: 12 additions & 17 deletions lib/software/serviceMigrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func cleanServiceName(name string) string {
return cleaned
}

func getRealServiceName(client *ssh.Client, name string) (string, error) {
func getRealServiceName(client *ssh.Client, name string, migrationLogger *Logger) (string, error) {
session, err := client.NewSession()
if err != nil {
migrationLogger.Printf(ERROR, "Failed to create SSH session: %v\n", err)
Expand Down Expand Up @@ -89,7 +89,7 @@ func getRealServiceName(client *ssh.Client, name string) (string, error) {
return name, nil
}

func getServiceInfo(client *ssh.Client, name string) ServiceInfo {
func getServiceInfo(client *ssh.Client, name string, migrationLogger *Logger) ServiceInfo {
migrationLogger.Printf(INFO, "Getting service info for: %s\n", name)

info := ServiceInfo{
Expand Down Expand Up @@ -137,7 +137,7 @@ func getServiceInfo(client *ssh.Client, name string) ServiceInfo {
}
}
case "alias":
if realName, err := getRealServiceName(client, name); err == nil {
if realName, err := getRealServiceName(client, name, migrationLogger); err == nil {
info.Name = realName
newSession, err := client.NewSession()
if err == nil {
Expand Down Expand Up @@ -247,7 +247,7 @@ func findPackageRelatedServices(client *ssh.Client, packageName string) ([]strin
return services, nil
}

func findServices(client *ssh.Client, name string) []ServiceInfo {
func findServices(client *ssh.Client, name string, migrationLogger *Logger) []ServiceInfo {
migrationLogger.Printf(INFO, "Finding services for package: %s\n", name)

var services []ServiceInfo
Expand All @@ -260,7 +260,7 @@ func findServices(client *ssh.Client, name string) []ServiceInfo {
if err != nil {
migrationLogger.Printf(WARN, "Failed to find services for package: %v\n", err)
migrationLogger.Printf(INFO, "Trying to find real service name for package: %s\n", name)
realName, err := getRealServiceName(client, name)
realName, err := getRealServiceName(client, name, migrationLogger)
if err != nil {
migrationLogger.Printf(WARN, "Failed to find real service name for package: %v\n", err)
}
Expand All @@ -277,15 +277,15 @@ func findServices(client *ssh.Client, name string) []ServiceInfo {
migrationLogger.Printf(INFO, "Found %d relevant services for package\n", len(filteredServices))

for _, serviceName := range filteredServices {
info := getServiceInfo(client, serviceName)
info := getServiceInfo(client, serviceName, migrationLogger)
services = append(services, info)
}

migrationLogger.Printf(INFO, "Completed service discovery with %d services\n", len(services))
return services
}

func getSystemType(client *ssh.Client) (SystemType, error) {
func getSystemType(client *ssh.Client, migrationLogger *Logger) (SystemType, error) {
session, err := client.NewSession()
if err != nil {
migrationLogger.Printf(ERROR, "Failed to create SSH session: %v\n", err)
Expand Down Expand Up @@ -325,23 +325,18 @@ func getSystemType(client *ssh.Client) (SystemType, error) {
}
}

func serviceMigrator(sourceClient *ssh.Client, targetClient *ssh.Client, packageName, uuid string) error {
if err := initLoggerWithUUID(uuid); err != nil {
return fmt.Errorf("failed to initialize logger: %v", err)
}
defer migrationLogger.Close()

func serviceMigrator(sourceClient *ssh.Client, targetClient *ssh.Client, packageName, uuid string, migrationLogger *Logger) error {
migrationLogger.Printf(INFO, "Starting service migration for package: %s (UUID: %s)\n", packageName, uuid)

migrationLogger.Printf(DEBUG, "Detecting source system type\n")
sourceType, err := getSystemType(sourceClient)
sourceType, err := getSystemType(sourceClient, migrationLogger)
if err != nil {
migrationLogger.Printf(ERROR, "Failed to detect source system type: %v\n", err)
return fmt.Errorf("failed to detect source system type: %v", err)
}

migrationLogger.Printf(DEBUG, "Detecting target system type\n")
targetType, err := getSystemType(targetClient)
targetType, err := getSystemType(targetClient, migrationLogger)
if err != nil {
migrationLogger.Printf(ERROR, "Failed to detect target system type: %v\n", err)
return fmt.Errorf("failed to detect target system type: %v", err)
Expand All @@ -352,7 +347,7 @@ func serviceMigrator(sourceClient *ssh.Client, targetClient *ssh.Client, package
return fmt.Errorf("system type mismatch: source=%v, target=%v", sourceType, targetType)
}

services := findServices(sourceClient, packageName)
services := findServices(sourceClient, packageName, migrationLogger)
if len(services) == 0 {
migrationLogger.Printf(WARN, "No services found for package: %s\n", packageName)
return nil
Expand Down Expand Up @@ -486,7 +481,7 @@ func serviceMigrator(sourceClient *ssh.Client, targetClient *ssh.Client, package
}
_ = session.Close()

err = listenPortsValidator(sourceClient, targetClient, service.Name, uuid)
err = listenPortsValidator(sourceClient, targetClient, service.Name, migrationLogger)
if err != nil {
issues = append(issues, err.Error())
}
Expand Down

0 comments on commit 1c2cb22

Please sign in to comment.