Skip to content

Commit

Permalink
Implement new custom tag struct
Browse files Browse the repository at this point in the history
  • Loading branch information
FerrahWolfeh committed Aug 29, 2023
1 parent c2fd4e4 commit 34c6b3f
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 59 deletions.
12 changes: 8 additions & 4 deletions benches/post_filter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use ahash::AHashSet;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use ibdl_common::post::{rating::Rating, Post};
use ibdl_common::post::{
rating::Rating,
tags::{Tag, TagType},
Post,
};
use rand::{
distributions::{Alphanumeric, DistString},
seq::SliceRandom,
Expand Down Expand Up @@ -175,9 +179,9 @@ fn seed_data(num: u64) -> (Vec<Post>, AHashSet<String>) {

let ext = EXTENSIONS.choose(&mut rng).unwrap().to_string();

let tags: Vec<String> = TAGS
let tags = TAGS
.choose_multiple(&mut rng, rn)
.map(|t| t.to_string())
.map(|t| Tag::new(t, TagType::General))
.collect();

let rating = *RATINGS.choose(&mut rng).unwrap();
Expand Down Expand Up @@ -213,7 +217,7 @@ pub fn blacklist_filter(list: Vec<Post>, tags: &AHashSet<String>, safe: bool) ->

if !blacklist.is_empty() {
let secondary_sz = lst.len();
lst.retain(|c| !c.tags.iter().any(|s| blacklist.contains(s)));
lst.retain(|c| !c.tags.iter().any(|s| blacklist.contains(&s.tag())));

let bp = secondary_sz - lst.len();
removed += bp as u64;
Expand Down
5 changes: 3 additions & 2 deletions ibdl-common/src/post/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ use std::{cmp::Ordering, fmt::Debug, ops::Not};

use crate::ImageBoards;

use self::rating::Rating;
use self::{rating::Rating, tags::Tag};

pub mod error;
pub mod extension;
pub mod rating;
pub mod tags;

/// Special enum to simplify the selection of the output file name when downloading a [`Post`]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
Expand Down Expand Up @@ -83,7 +84,7 @@ pub struct Post {
/// Set of tags associated with the post.
///
/// Used to exclude posts according to a blacklist
pub tags: Vec<String>,
pub tags: Vec<Tag>,
}

impl Debug for Post {
Expand Down
44 changes: 44 additions & 0 deletions ibdl-common/src/post/tags.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Tag {
tag: String,
tag_type: TagType,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum TagType {
Author,
Copyright,
Character,
/// Exclusive to e621/926
Species,
General,
/// Exclusive to e621/926
Lore,
Meta,
}

impl Tag {
pub fn new(text: &str, tag_type: TagType) -> Self {
Self {
tag: text.to_string(),
tag_type,
}
}

pub fn tag(&self) -> String {
self.tag.clone()
}

pub fn tag_type(&self) -> TagType {
self.tag_type
}

pub fn is_prompt_tag(&self) -> bool {
match self.tag_type {
TagType::Author | TagType::Copyright | TagType::Lore | TagType::Meta => false,
TagType::Character | TagType::Species | TagType::General => true,
}
}
}
19 changes: 17 additions & 2 deletions ibdl-core/src/async_queue/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ use futures::StreamExt;
use ibdl_common::log::debug;
use ibdl_common::post::error::PostError;
use ibdl_common::post::rating::Rating;
use ibdl_common::post::tags::TagType;
use ibdl_common::post::{NameType, Post, PostQueue};
use ibdl_common::reqwest::Client;
use ibdl_common::tokio::spawn;
Expand Down Expand Up @@ -461,7 +462,14 @@ impl Queue {
}
};

let prompt = post.tags.join(", ");
let tag_list = Vec::from_iter(
post.tags
.iter()
.filter(|t| t.is_prompt_tag())
.map(|tag| tag.tag()),
);

let prompt = tag_list.join(", ");

let f1 = prompt.replace('_', " ");

Expand Down Expand Up @@ -546,7 +554,14 @@ impl Queue {
.open(output.join(format!("{}.txt", post.name(name_type))))
.await?;

let prompt = post.tags.join(", ");
let tag_list = Vec::from_iter(
post.tags
.iter()
.filter(|t| t.is_prompt_tag())
.map(|tag| tag.tag()),
);

let prompt = tag_list.join(", ");

let f1 = prompt.replace('_', " ");

Expand Down
18 changes: 16 additions & 2 deletions ibdl-core/src/queue/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,14 @@ impl Queue {
}
};

let prompt = post.tags.join(", ");
let tag_list = Vec::from_iter(
post.tags
.iter()
.filter(|t| t.is_prompt_tag())
.map(|tag| tag.tag()),
);

let prompt = tag_list.join(", ");

let f1 = prompt.replace('_', " ");
//let f2 = f1.replace('(', "\\(");
Expand Down Expand Up @@ -547,7 +554,14 @@ impl Queue {
.open(output.join(format!("{}.txt", post.name(name_type))))
.await?;

let prompt = post.tags.join(", ");
let tag_list = Vec::from_iter(
post.tags
.iter()
.filter(|t| t.is_prompt_tag())
.map(|tag| tag.tag()),
);

let prompt = tag_list.join(", ");

let f1 = prompt.replace('_', " ");
//let f2 = f1.replace('(', "\\(");
Expand Down
5 changes: 3 additions & 2 deletions ibdl-extractors/src/blacklist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use ibdl_common::directories::ProjectDirs;
use ibdl_common::log::debug;
use ibdl_common::post::extension::Extension;
use ibdl_common::post::rating::Rating;
use ibdl_common::post::tags::{Tag, TagType};
use ibdl_common::post::Post;
use ibdl_common::serde::{self, Deserialize, Serialize};
use ibdl_common::tokio::fs::{create_dir_all, read_to_string, File};
Expand Down Expand Up @@ -208,12 +209,12 @@ impl BlacklistFilter {
if !self.gbl_tags.is_empty() {
debug!("Removing posts with tags {:?}", self.gbl_tags);

original_list.retain(|c| !c.tags.iter().any(|s| self.gbl_tags.contains(s)));
original_list.retain(|c| !c.tags.iter().any(|s| self.gbl_tags.contains(&s.tag())));
}
if self.ignore_animated {
original_list.retain(|post| {
let ext = post.extension.as_str();
!(post.tags.contains(&String::from("animated")) || ve.contains(ext))
!(post.tags.contains(&Tag::new("animated", TagType::Meta)) || ve.contains(ext))
})
}
let bp = fsize - original_list.len();
Expand Down
3 changes: 1 addition & 2 deletions ibdl-extractors/src/websites/danbooru/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ impl Extractor for DanbooruExtractor {
let batch = parsed_json.into_iter().filter(|c| c.file_url.is_some());

let mapper_iter = batch.map(|c| {
let tags = c.tag_string.unwrap();
let tag_list = Vec::from_iter(tags.split(' ').map(|tag| tag.to_string()));
let tag_list = c.map_tags();

let rt = c.rating.unwrap();
let rating = if rt == "s" {
Expand Down
42 changes: 40 additions & 2 deletions ibdl-extractors/src/websites/danbooru/models.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,50 @@
use ibdl_common::serde::{self, Deserialize, Serialize};
use ibdl_common::{
post::tags::{Tag, TagType},
serde::{self, Deserialize, Serialize},
};

#[derive(Serialize, Deserialize, Debug)]
#[serde(crate = "self::serde")]
pub struct DanbooruPost {
pub id: Option<u64>,
pub md5: Option<String>,
pub file_url: Option<String>,
pub tag_string: Option<String>,
pub tag_string_general: Option<String>,
pub tag_string_character: Option<String>,
pub tag_string_copyright: Option<String>,
pub tag_string_artist: Option<String>,
pub tag_string_meta: Option<String>,
pub file_ext: Option<String>,
pub rating: Option<String>,
}

impl DanbooruPost {
pub fn map_tags(&self) -> Vec<Tag> {
let mut tags = Vec::with_capacity(64);
if let Some(tagstr) = &self.tag_string_artist {
tags.extend(tagstr.split(' ').map(|tag| Tag::new(tag, TagType::Author)))
}
if let Some(tagstr) = &self.tag_string_copyright {
tags.extend(
tagstr
.split(' ')
.map(|tag| Tag::new(tag, TagType::Copyright)),
)
}
if let Some(tagstr) = &self.tag_string_character {
tags.extend(
tagstr
.split(' ')
.map(|tag| Tag::new(tag, TagType::Character)),
)
}
if let Some(tagstr) = &self.tag_string_general {
tags.extend(tagstr.split(' ').map(|tag| Tag::new(tag, TagType::General)))
}
if let Some(tagstr) = &self.tag_string_meta {
tags.extend(tagstr.split(' ').map(|tag| Tag::new(tag, TagType::Meta)))
}

tags
}
}
40 changes: 1 addition & 39 deletions ibdl-extractors/src/websites/e621/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,45 +231,7 @@ impl Extractor for E621Extractor {
let mut post_list: Vec<Post> = Vec::with_capacity(post_iter.size_hint().0);

post_iter.for_each(|c| {
let tag_array = [
c.tags.artist.len(),
c.tags.character.len(),
c.tags.general.len(),
c.tags.copyright.len(),
c.tags.lore.len(),
c.tags.meta.len(),
c.tags.species.len(),
];

let chunks = tag_array.chunks_exact(4);
let remainder = chunks.remainder();

let sum = chunks.fold([0usize; 4], |mut acc, chunk| {
let chunk: [usize; 4] = chunk.try_into().unwrap();
for i in 0..4 {
acc[i] += chunk[i];
}
acc
});

let remainder: usize = remainder.iter().sum();

let mut reduced: usize = 0;
for i in sum {
reduced += i;
}
let full_size = reduced + remainder;

//let full_size = tag_array.iter().sum();

let mut tag_list = Vec::with_capacity(full_size);
tag_list.append(&mut c.tags.character);
tag_list.append(&mut c.tags.artist);
tag_list.append(&mut c.tags.general);
tag_list.append(&mut c.tags.copyright);
tag_list.append(&mut c.tags.lore);
tag_list.append(&mut c.tags.meta);
tag_list.append(&mut c.tags.species);
let tag_list = c.tags.map_tags();

let unit = Post {
id: c.id.unwrap(),
Expand Down
28 changes: 27 additions & 1 deletion ibdl-extractors/src/websites/e621/models.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use ibdl_common::serde::{self, Deserialize, Serialize};
use ibdl_common::{
post::tags::{Tag, TagType},
serde::{self, Deserialize, Serialize},
};

#[derive(Serialize, Deserialize, Debug)]
#[serde(crate = "self::serde")]
Expand Down Expand Up @@ -44,3 +47,26 @@ pub struct Tags {
pub lore: Vec<String>,
pub meta: Vec<String>,
}

impl Tags {
pub fn map_tags(&self) -> Vec<Tag> {
let mut tag_list = Vec::with_capacity(64);
tag_list.extend(self.general.iter().map(|t| Tag::new(t, TagType::General)));
tag_list.extend(self.species.iter().map(|t| Tag::new(t, TagType::Species)));
tag_list.extend(
self.character
.iter()
.map(|t| Tag::new(t, TagType::Character)),
);
tag_list.extend(
self.copyright
.iter()
.map(|t| Tag::new(t, TagType::Copyright)),
);
tag_list.extend(self.artist.iter().map(|t| Tag::new(t, TagType::Author)));
tag_list.extend(self.lore.iter().map(|t| Tag::new(t, TagType::Lore)));
tag_list.extend(self.meta.iter().map(|t| Tag::new(t, TagType::Meta)));

tag_list
}
}
5 changes: 3 additions & 2 deletions ibdl-extractors/src/websites/gelbooru/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

use async_trait::async_trait;
use ibdl_common::post::extension::Extension;
use ibdl_common::post::tags::{Tag, TagType};
use ibdl_common::reqwest::Client;
use ibdl_common::serde_json::{self, Value};
use ibdl_common::tokio::time::{sleep, Instant};
Expand Down Expand Up @@ -278,7 +279,7 @@ impl GelbooruExtractor {
let mut tags = Vec::with_capacity(tag_iter.size_hint().0);

tag_iter.for_each(|f| {
tags.push(f.to_string());
tags.push(Tag::new(f, TagType::General));
});

let rating = Rating::from_rating_str(f["rating"].as_str().unwrap());
Expand Down Expand Up @@ -334,7 +335,7 @@ impl GelbooruExtractor {
let mut tags = Vec::with_capacity(tag_iter.size_hint().0);

tag_iter.for_each(|i| {
tags.push(i.to_string());
tags.push(Tag::new(i, TagType::General));
});

let extension = extract_ext_from_url!(url);
Expand Down
3 changes: 2 additions & 1 deletion ibdl-extractors/src/websites/moebooru/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Post extractor for `https://konachan.com` and other Moebooru imageboards
use async_trait::async_trait;
use ibdl_common::post::extension::Extension;
use ibdl_common::post::tags::{Tag, TagType};
use ibdl_common::reqwest::Client;
use ibdl_common::{
client, extract_ext_from_url, join_tags,
Expand Down Expand Up @@ -220,7 +221,7 @@ impl Extractor for MoebooruExtractor {
let ext = extract_ext_from_url!(url);

tag_iter.for_each(|i| {
tags.push(i.to_string());
tags.push(Tag::new(i, TagType::General));
});

let unit = Post {
Expand Down

0 comments on commit 34c6b3f

Please sign in to comment.