Skip to content

Commit

Permalink
Merge pull request #74 from arkedge/minor_refactors
Browse files Browse the repository at this point in the history
簡単なリファクタ
  • Loading branch information
KOBA789 authored Apr 22, 2024
2 parents bfd3dab + 49f70e4 commit 80a13f1
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 28 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions kble/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,7 @@ clap.workspace = true
serde.workspace = true
serde_yaml = "0.9"
serde_with = "3.7"
tracing-subscriber.workspace = true
tracing.workspace = true
notalawyer.workspace = true
notalawyer-clap.workspace = true
28 changes: 28 additions & 0 deletions kble/src/app.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use crate::{plug, spaghetti::Config};
use anyhow::Result;
use futures::future;
use futures::StreamExt;
use std::collections::HashMap;

pub async fn run(config: &Config) -> Result<()> {
let mut sinks = HashMap::new();
let mut streams = HashMap::new();
for (name, url) in config.plugs().iter() {
let (sink, stream) = plug::connect(url).await?;
sinks.insert(name.as_str(), sink);
streams.insert(name.as_str(), stream);
}
let mut edges = vec![];
for (stream_name, sink_name) in config.links().iter() {
let Some(stream) = streams.remove(stream_name.as_str()) else {
unreachable!("No such plug: {stream_name}");
};
let Some(sink) = sinks.remove(sink_name.as_str()) else {
unreachable!("No such plug or already used: {sink_name}");
};
let edge = stream.forward(sink);
edges.push(edge);
}
future::try_join_all(edges).await?;
Ok(())
}
21 changes: 18 additions & 3 deletions kble/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ use std::path::PathBuf;
use anyhow::{Context, Result};
use clap::Parser;
use notalawyer_clap::*;
use tracing_subscriber::{prelude::*, EnvFilter};

mod app;
mod plug;
mod spaghetti;

use spaghetti::{Config, Raw};

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
Expand All @@ -21,15 +25,26 @@ impl Args {
.open(&self.spaghetti)
.with_context(|| format!("Failed to open {:?}", &self.spaghetti))?;
let spagetthi_rdr = std::io::BufReader::new(spaghetti_file);
serde_yaml::from_reader(spagetthi_rdr)
.with_context(|| format!("Unable to parse {:?}", self.spaghetti))
let raw: Config<Raw> = serde_yaml::from_reader(spagetthi_rdr)
.with_context(|| format!("Unable to parse {:?}", self.spaghetti))?;
raw.validate()
.with_context(|| format!("Invalid configuration in {:?}", self.spaghetti))
}
}

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::fmt::layer()
.with_ansi(false)
.with_writer(std::io::stderr),
)
.with(EnvFilter::from_default_env())
.init();

let args = Args::parse_with_license_notice(include_notice!());
let config = args.load_spaghetti_config()?;
config.run().await?;
app::run(&config).await?;
Ok(())
}
110 changes: 85 additions & 25 deletions kble/src/spaghetti.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,74 @@
use anyhow::{anyhow, Result};
use std::collections::HashMap;

use anyhow::{anyhow, Result};
use futures::{future, StreamExt};
use serde::{Deserialize, Serialize};
use url::Url;

use crate::plug;

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct Config {
pub struct Inner {
plugs: HashMap<String, Url>,
links: HashMap<String, String>,
}

impl Config {
pub async fn run(&self) -> Result<()> {
let mut sinks = HashMap::new();
let mut streams = HashMap::new();
for (name, url) in self.plugs.iter() {
let (sink, stream) = plug::connect(url).await?;
sinks.insert(name.as_str(), sink);
streams.insert(name.as_str(), stream);
#[derive(PartialEq, Debug)]
pub enum Raw {}
pub enum Validated {}

#[derive(Serialize, Debug, Clone, PartialEq, Eq)]
pub struct Config<State = Validated> {
#[serde(flatten)]
inner: Inner,
state: std::marker::PhantomData<State>,
}

impl<'de> serde::Deserialize<'de> for Config<Raw> {
fn deserialize<D>(deserializer: D) -> Result<Config<Raw>, D::Error>
where
D: serde::Deserializer<'de>,
{
let inner = Inner::deserialize(deserializer)?;
Ok(Config::new(inner))
}
}

impl<State> Config<State> {
fn new(inner: Inner) -> Self {
Config {
inner,
state: std::marker::PhantomData,
}
let mut edges = vec![];
for (stream_name, sink_name) in self.links.iter() {
let Some(stream) = streams.remove(stream_name.as_str()) else {
}
}

impl Config<Raw> {
pub fn validate(self) -> Result<Config<Validated>> {
use std::collections::HashSet;
let mut seen_sinks = HashSet::new();

for (stream_name, sink_name) in self.inner.links.iter() {
if !self.inner.plugs.contains_key(stream_name) {
return Err(anyhow!("No such plug: {stream_name}"));
};
let Some(sink) = sinks.remove(sink_name.as_str()) else {
return Err(anyhow!("No such plug or already used: {sink_name}"));
};
let edge = stream.forward(sink);
edges.push(edge);
}
if !self.inner.plugs.contains_key(sink_name) {
return Err(anyhow!("No such plug: {sink_name}"));
}

if seen_sinks.contains(sink_name) {
return Err(anyhow!("Sink {sink_name} used more than once"));
}
seen_sinks.insert(sink_name);
}
future::try_join_all(edges).await?;
Ok(())
Ok(Config::new(self.inner))
}
}

impl Config<Validated> {
pub fn plugs(&self) -> &HashMap<String, Url> {
&self.inner.plugs
}

pub fn links(&self) -> &HashMap<String, String> {
&self.inner.links
}
}

Expand All @@ -47,7 +81,7 @@ mod tests {
#[test]
fn test_de() {
let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsync: seriald\n";
let expected = Config {
let inner = Inner {
plugs: HashMap::from_iter([
("tfsync".to_string(), Url::parse("exec:tfsync foo").unwrap()),
(
Expand All @@ -57,7 +91,33 @@ mod tests {
]),
links: HashMap::from_iter([("tfsync".to_string(), "seriald".to_string())]),
};
let expected = Config {
inner,
state: std::marker::PhantomData,
};
let actual = serde_yaml::from_str(yaml).unwrap();
assert_eq!(expected, actual);
actual.validate().unwrap();
}

#[test]
fn test_de_invalid_dest() {
let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsync: serialdxxxx\n";
let actual: Config<Raw> = serde_yaml::from_str(yaml).unwrap();
assert!(actual.validate().is_err());
}

#[test]
fn test_de_invalid_source() {
let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsyncxxxx: seriald\n";
let actual: Config<Raw> = serde_yaml::from_str(yaml).unwrap();
assert!(actual.validate().is_err());
}

#[test]
fn test_de_duplicate_sink() {
let yaml = "plugs:\n tfsync: exec:tfsync foo\n seriald: ws://seriald.local/\nlinks:\n tfsync: seriald\n seriald: seriald\n";
let actual: Config<Raw> = serde_yaml::from_str(yaml).unwrap();
assert!(actual.validate().is_err());
}
}

0 comments on commit 80a13f1

Please sign in to comment.