Skip to content

Commit

Permalink
Correct updating of config file when (un)registering clusters
Browse files Browse the repository at this point in the history
  • Loading branch information
Westwooo committed Dec 10, 2024
1 parent b18899e commit 4211cb6
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 84 deletions.
68 changes: 26 additions & 42 deletions src/cli/cbenv_register.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use crate::config::{CapellaOrganizationConfig, ClusterConfig, ShellConfig, DEFAULT_KV_BATCH_SIZE};
use crate::state::State;
use std::fs;
use std::sync::{Arc, Mutex, MutexGuard};

use crate::cli::error::generic_error;
use crate::cli::util::get_username_and_password;
use crate::cli::util::{get_username_and_password, read_config_file, update_config_file};
use crate::config::{ClusterConfig, DEFAULT_KV_BATCH_SIZE};
use crate::state::State;
use crate::{
ClusterTimeouts, RemoteCluster, RemoteClusterResources, RemoteClusterType, RustTlsConfig,
};
Expand All @@ -13,6 +10,7 @@ use nu_engine::CallExt;
use nu_protocol::engine::{Command, EngineState, Stack};
use nu_protocol::Value::Nothing;
use nu_protocol::{Category, PipelineData, ShellError, Signature, Span, SyntaxShape};
use std::sync::{Arc, Mutex, MutexGuard};

#[derive(Clone)]
pub struct CbEnvRegister {
Expand Down Expand Up @@ -151,7 +149,7 @@ fn clusters_register(
.get_flag(engine_state, stack, "tls-accept-all-certs")?
.unwrap_or(true);
let cert_path = call.get_flag(engine_state, stack, "tls-cert-path")?;
let save = call.get_flag(engine_state, stack, "save")?.unwrap_or(false);
let save = call.has_flag(engine_state, stack, "save")?;
let capella = call.get_flag(engine_state, stack, "capella-organization")?;
let project = call.get_flag(engine_state, stack, "project")?;
let display_name = call.get_flag(engine_state, stack, "display-name")?;
Expand Down Expand Up @@ -193,10 +191,10 @@ fn clusters_register(
);

let mut guard = state.lock().unwrap();
guard.add_cluster(identifier, cluster)?;
guard.add_cluster(identifier.clone(), cluster)?;

if save {
update_config_file(&mut guard, call.head)?;
save_new_cluster_config(&mut guard, call.head, identifier)?;
}

Ok(PipelineData::Value(
Expand All @@ -207,42 +205,28 @@ fn clusters_register(
))
}

pub fn update_config_file(guard: &mut MutexGuard<State>, span: Span) -> Result<(), ShellError> {
let path = match guard.config_path() {
Some(p) => p,
None => {
return Err(generic_error(
"A config path must be discoverable to save config",
None,
span,
));
}
};
let mut cluster_configs = Vec::new();
for (identifier, cluster) in guard.clusters() {
cluster_configs.push(ClusterConfig::from((identifier.clone(), cluster)))
}
let mut capella_configs = Vec::new();
for (identifier, c) in guard.capella_orgs() {
capella_configs.push(CapellaOrganizationConfig::new(
identifier.clone(),
c.secret_key(),
c.access_key(),
Some(c.timeout()),
c.default_project(),
fn save_new_cluster_config(
guard: &mut MutexGuard<State>,
span: Span,
identifier: String,
) -> Result<(), ShellError> {
let mut config = read_config_file(guard, span)?;
let clusters = config.clusters_mut();

if clusters.iter().any(|c| c.identifier() == identifier) {
return Err(generic_error(
format!(
"failed to update config file: cluster with identifier {} already exists",
identifier
),
None,
))
span,
));
}

let config = ShellConfig::new_from_clusters(cluster_configs, capella_configs);
let new_cluster = guard.clusters().get(&identifier).unwrap();

fs::write(
path,
config
.to_str()
.map_err(|e| generic_error(format!("Failed to write config file {}", e), None, span))?,
)
.map_err(|e| generic_error(format!("Failed to write config file {}", e), None, span))?;
clusters.push(ClusterConfig::from((identifier.clone(), new_cluster)));

Ok(())
update_config_file(guard, span, config)
}
31 changes: 26 additions & 5 deletions src/cli/cbenv_unregister.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use crate::cli::cbenv_register::update_config_file;
use crate::cli::util::{read_config_file, update_config_file};
use crate::state::State;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, MutexGuard};

use crate::cli::error::{cluster_not_found_error, generic_error};
use nu_engine::command_prelude::Call;
use nu_engine::CallExt;
use nu_protocol::engine::{Command, EngineState, Stack};
use nu_protocol::Value::Nothing;
use nu_protocol::{Category, PipelineData, ShellError, Signature, SyntaxShape};
use nu_protocol::{Category, PipelineData, ShellError, Signature, Span, SyntaxShape};

#[derive(Clone)]
pub struct CbEnvUnregister {
Expand Down Expand Up @@ -63,7 +63,7 @@ fn clusters_unregister(
_input: PipelineData,
) -> Result<PipelineData, ShellError> {
let identifier: String = call.req(engine_state, stack, 0)?;
let save = call.get_flag(engine_state, stack, "save")?.unwrap_or(false);
let save = call.has_flag(engine_state, stack, "save")?;

let mut guard = state.lock().unwrap();
if guard.active() == identifier {
Expand All @@ -79,7 +79,7 @@ fn clusters_unregister(
};

if save {
update_config_file(&mut guard, call.head)?;
remove_cluster_config(&mut guard, call.head, identifier)?;
};

Ok(PipelineData::Value(
Expand All @@ -89,3 +89,24 @@ fn clusters_unregister(
None,
))
}

fn remove_cluster_config(
guard: &mut MutexGuard<State>,
span: Span,
identifier: String,
) -> Result<(), ShellError> {
let mut config = read_config_file(guard, span)?;
let clusters = config.clusters_mut();

if let Some(cluster_index) = clusters.iter().position(|c| c.identifier() == identifier) {
clusters.remove(cluster_index);
} else {
return Err(generic_error(
format!("cluster {} not in config file", identifier),
None,
span,
));
}

update_config_file(guard, span, config)
}
57 changes: 57 additions & 0 deletions src/cli/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ use crate::cli::generic_error;
use crate::cli::CBShellError::ClusterNotFound;
use crate::client::cloud_json::Cluster;
use crate::client::CapellaClient;
use crate::config::ShellConfig;
use crate::state::State;
use crate::{read_input, RemoteCluster, RemoteClusterType};
use log::debug;
use nu_engine::command_prelude::Call;
use nu_engine::CallExt;
use nu_protocol::ast::PathMember;
Expand All @@ -18,6 +20,7 @@ use nu_protocol::{Record, Signals};
use nu_utils::SharedCow;
use num_traits::cast::ToPrimitive;
use regex::Regex;
use std::fs;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::Duration;

Expand Down Expand Up @@ -526,6 +529,60 @@ pub fn get_username_and_password(
Ok((username, password))
}

pub fn read_config_file(
guard: &mut MutexGuard<State>,
span: Span,
) -> Result<ShellConfig, ShellError> {
let path = match guard.config_path() {
Some(p) => p,
None => {
return Err(generic_error(
"A config path must be discoverable to save config",
None,
span,
));
}
};

let config = fs::read(path)
.map_err(|e| generic_error(format!("Could not read current config: {}", e), None, span))?;

let shell_config = ShellConfig::from_str(std::str::from_utf8(&config).unwrap());

debug!("config read from {:?} - {:?}", path, shell_config);

Ok(shell_config)
}

pub fn update_config_file(
guard: &mut MutexGuard<State>,
span: Span,
config: ShellConfig,
) -> Result<(), ShellError> {
let path = match guard.config_path() {
Some(p) => p,
None => {
return Err(generic_error(
"A config path must be discoverable to save config",
None,
span,
));
}
};

debug!("updating config at {:?} to {:?}", path, config);

fs::write(
path,
config
.to_str()
.map_err(|e| generic_error(format!("Failed to write config file {}", e), None, span))?,
)
.map_err(|e| generic_error(format!("Failed to write config file {}", e), None, span))?;

Ok(())
}

#[cfg(test)]
mod tests {
use crate::cli::util::duration_to_golang_string;
Expand Down
89 changes: 60 additions & 29 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::remote_cluster::{RemoteCluster, RemoteClusterType};
use crate::remote_cluster::{ClusterTimeouts, RemoteCluster, RemoteClusterType};
use crate::state::Provider;
use log::debug;
use log::error;
Expand Down Expand Up @@ -251,25 +251,6 @@ pub struct CapellaOrganizationConfig {
}

impl CapellaOrganizationConfig {
pub fn new(
identifier: String,
secret_key: String,
access_key: String,
management_timeout: Option<Duration>,
default_project: Option<String>,
api_endpoint: Option<String>,
) -> Self {
Self {
identifier,
credentials: OrganizationCredentials {
access_key,
secret_key,
},
management_timeout,
default_project,
api_endpoint,
}
}
pub fn identifier(&self) -> String {
self.identifier.clone()
}
Expand Down Expand Up @@ -511,6 +492,55 @@ impl ClusterConfig {
}
}

impl From<ClusterTimeouts> for ClusterConfigTimeouts {
fn from(timeouts: ClusterTimeouts) -> Self {
let data_timeout = if timeouts.data_timeout() == DEFAULT_DATA_TIMEOUT {
None
} else {
Some(timeouts.data_timeout())
};

let query_timeout = if timeouts.query_timeout() == DEFAULT_QUERY_TIMEOUT {
None
} else {
Some(timeouts.query_timeout())
};

let analytics_timeout = if timeouts.analytics_timeout() == DEFAULT_ANALYTICS_TIMEOUT {
None
} else {
Some(timeouts.analytics_timeout())
};

let search_timeout = if timeouts.search_timeout() == DEFAULT_SEARCH_TIMEOUT {
None
} else {
Some(timeouts.search_timeout())
};

let management_timeout = if timeouts.management_timeout() == DEFAULT_MANAGEMENT_TIMEOUT {
None
} else {
Some(timeouts.management_timeout())
};

let transaction_timeout = if timeouts.transaction_timeout() == DEFAULT_TRANSACTION_TIMEOUT {
None
} else {
Some(timeouts.transaction_timeout())
};

Self {
data_timeout,
query_timeout,
analytics_timeout,
search_timeout,
management_timeout,
transaction_timeout,
}
}
}

impl From<(String, &RemoteCluster)> for ClusterConfig {
fn from(cluster: (String, &RemoteCluster)) -> Self {
let cloud = cluster.1.capella_org();
Expand All @@ -529,28 +559,27 @@ impl From<(String, &RemoteCluster)> for ClusterConfig {
}
};

let kv_batch_size = if cluster.1.kv_batch_size() == DEFAULT_KV_BATCH_SIZE {
None
} else {
Some(cluster.1.kv_batch_size())
};

Self {
identifier: cluster.0,
conn_string: cluster.1.hostnames().join(","),
default_collection: cluster.1.active_collection(),
default_scope: cluster.1.active_scope(),
default_bucket: cluster.1.active_bucket(),
timeouts: ClusterConfigTimeouts {
data_timeout: Some(cluster.1.timeouts().data_timeout()),
query_timeout: Some(cluster.1.timeouts().query_timeout()),
analytics_timeout: Some(cluster.1.timeouts().analytics_timeout()),
search_timeout: Some(cluster.1.timeouts().search_timeout()),
management_timeout: Some(cluster.1.timeouts().management_timeout()),
transaction_timeout: Some(cluster.1.timeouts().transaction_timeout()),
},
timeouts: ClusterConfigTimeouts::from(cluster.1.timeouts()),
tls: tls_config,
credentials: ClusterCredentials {
username: Some(cluster.1.username().to_string()),
password: Some(cluster.1.password().to_string()),
},
capella_org: cloud,
project: cluster.1.project(),
kv_batch_size: Some(cluster.1.kv_batch_size()),
kv_batch_size,
display_name: cluster.1.display_name(),
// This is a config option for dev ony so we won't want to write to file
cluster_type: None,
Expand All @@ -562,9 +591,11 @@ impl From<(String, &RemoteCluster)> for ClusterConfig {
pub struct OrganizationCredentials {
#[serde(default)]
#[serde(rename(deserialize = "access-key", serialize = "access-key"))]
#[serde(skip_serializing_if = "String::is_empty")]
access_key: String,
#[serde(default)]
#[serde(rename(deserialize = "secret-key", serialize = "secret-key"))]
#[serde(skip_serializing_if = "String::is_empty")]
secret_key: String,
}

Expand Down
8 changes: 0 additions & 8 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,14 +356,6 @@ impl RemoteCapellaOrganization {
}
}

pub fn secret_key(&self) -> String {
self.secret_key.clone()
}

pub fn access_key(&self) -> String {
self.access_key.clone()
}

pub fn client(&self) -> Arc<CapellaClient> {
let mut c = self.client.lock().unwrap();
if c.is_none() {
Expand Down

0 comments on commit 4211cb6

Please sign in to comment.