From 7ce6dec35b06fecbd567173d0e2581f2bfa29b7e Mon Sep 17 00:00:00 2001 From: Sebastian Riedel Date: Fri, 12 Jan 2024 17:37:21 +0100 Subject: [PATCH] Add ability to reimport training data classifications --- lib/Cavil/Command/learn.pm | 36 +++++++++++++++++++++++++---- t/command_learn.t | 47 ++++++++++++++++++++++++++++++++++---- 2 files changed, 75 insertions(+), 8 deletions(-) diff --git a/lib/Cavil/Command/learn.pm b/lib/Cavil/Command/learn.pm index 742278593f..11f1d034ef 100644 --- a/lib/Cavil/Command/learn.pm +++ b/lib/Cavil/Command/learn.pm @@ -23,13 +23,41 @@ has description => 'Training data for machine learning'; has usage => sub ($self) { $self->extract_usage }; sub run ($self, @args) { - getopt \@args, 'e|export=s' => \my $export; - die 'Export directory is required' unless defined $export; + getopt \@args, + 'i|input=s' => \my $input, + 'o|output=s' => \my $output; + die 'Input or output directory is required' unless (defined $output || defined $input); my $app = $self->app; my $db = $app->pg->db; - my $root = path($export); + return _output($db, $output) if $output; + return _input($db, $input); +} + +sub _classify ($db, $name, $license) { + return 0 unless $name =~ /^(\w+).txt$/; + return $db->query( + 'UPDATE snippets SET license = ?, classified = true, approved = true WHERE hash = ? AND approved = false', + $license, $1)->rows; +} + +sub _input ($db, $input) { + my $root = path($input); + my $good = $root->child('good'); + my $bad = $root->child('bad'); + + return unless -d $good && -d $bad; + + my $count = 0; + $count += _classify($db, $_->basename, 1) for $good->list->each; + $count += _classify($db, $_->basename, 0) for $bad->list->each; + + say "Imported $count snippet classifications"; +} + +sub _output ($db, $output) { + my $root = path($output); my $good = $root->child('good')->make_path; my $bad = $root->child('bad')->make_path; @@ -72,7 +100,7 @@ Cavil::Command::learn - Cavil learn command script/cavil learn -e ./input Options: - -e, --export Export snippets for training machine learning models + -o, --output Export snippets for training machine learning models -h, --help Show this summary of available options =cut diff --git a/t/command_learn.t b/t/command_learn.t index 91877fdc90..4421d625e9 100644 --- a/t/command_learn.t +++ b/t/command_learn.t @@ -39,7 +39,7 @@ subtest 'Empty database' => sub { { open my $handle, '>', \$buffer; local *STDOUT = $handle; - $app->start('learn', '-e', "$dir"); + $app->start('learn', '-o', "$dir"); } like $buffer, qr/Exported 0 snippets/, 'no snippets'; ok -e $dir->child('good'), 'directory exists'; @@ -54,14 +54,14 @@ subtest 'Snippets added' => sub { $db->query('UPDATE snippets SET license = false, approved = true WHERE id = 1'); $db->query('UPDATE snippets SET license = true, approved = true WHERE id = 2'); $db->query('UPDATE snippets SET license = true, approved = false WHERE id = 3'); + my $dir = $tmp->child('two'); - subtest 'Export snippets' => sub { - my $dir = $tmp->child('two'); + subtest 'Output snippets' => sub { my $buffer = ''; { open my $handle, '>', \$buffer; local *STDOUT = $handle; - $app->start('learn', '-e', "$dir"); + $app->start('learn', '-o', "$dir"); } like $buffer, qr/Exporting snippet 1/, 'first snippet'; like $buffer, qr/Exporting snippet 2/, 'second snippet'; @@ -74,6 +74,45 @@ subtest 'Snippets added' => sub { is $good->size, 1, 'one file'; like $good->first->slurp, qr/Copyright Holder/, 'right content'; }; + + $db->query('UPDATE snippets SET license = true, approved = false WHERE id = 1'); + $db->query('UPDATE snippets SET license = false, approved = false WHERE id = 2'); + $dir->child('good', 'doesnotexist.txt')->spew('Whatever'); + $dir->child('bad', 'doesnotexist.txt')->spew('Whatever'); + + subtest 'Input snippets' => sub { + my $buffer = ''; + { + open my $handle, '>', \$buffer; + local *STDOUT = $handle; + $app->start('learn', '-i', "$dir"); + } + like $buffer, qr/Imported 2 snippet classifications/, 'two snippets imported'; + + my $first = $db->select('snippets', '*', {id => 1})->hash; + is $first->{license}, 0, 'is not a license'; + is $first->{classified}, 1, 'is classified'; + is $first->{approved}, 1, 'is approved'; + + my $second = $db->select('snippets', '*', {id => 2})->hash; + is $second->{license}, 1, 'is license'; + is $second->{approved}, 1, 'is approved'; + is $second->{classified}, 1, 'is classified'; + + my $third = $db->select('snippets', '*', {id => 3})->hash; + is $third->{approved}, 0, 'not approved'; + is $third->{classified}, 0, 'not classified'; + }; + + subtest 'Input snippets (repeat does nothing)' => sub { + my $buffer = ''; + { + open my $handle, '>', \$buffer; + local *STDOUT = $handle; + $app->start('learn', '-i', "$dir"); + } + like $buffer, qr/Imported 0 snippet classifications/, 'no snippets imported'; + }; }; done_testing();