-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds the capability to run a Python script or executable during evaluation, in order to get external metrics.
- Loading branch information
Showing
22 changed files
with
1,121 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import lbann | ||
import numpy as np | ||
import os | ||
import os.path | ||
import sys | ||
|
||
current_file = os.path.realpath(__file__) | ||
current_dir = os.path.dirname(current_file) | ||
module_name = os.path.splitext(os.path.basename(current_file))[0] | ||
|
||
sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python')) | ||
import test_util | ||
|
||
|
||
# ============================================== | ||
# Functionality for Python metrics | ||
# ============================================== | ||
# Note: The Python metric class imports this file as a module and calls | ||
# the functions below to return a value. | ||
def evaluate(experiment_path, rank):
    """Validate the arguments passed to a Python metric function.

    Returns a distinct negative code for each failed check so that a
    failing run shows which argument was wrong:

    * ``-1.0`` -- ``experiment_path`` is falsy or is not a ``str``
    * ``-2.0`` -- ``experiment_path`` is not ``'trainer0/model0'``
    * ``-3.0`` -- ``rank`` is not an ``int`` or is negative
    * ``1.0``  -- every check passed
    """
    if not (experiment_path and isinstance(experiment_path, str)):
        status = -1.0
    elif experiment_path != 'trainer0/model0':
        status = -2.0
    elif not (isinstance(rank, int) and rank >= 0):
        status = -3.0
    else:
        # All arguments look as expected
        status = 1.0
    return status
|
||
|
||
def return_weight_value_plus_one(experiment_path, rank):
    """Read a single dumped weight value from disk and return it plus one.

    ``experiment_path`` must have the form ``'<trainer>/<model>'``. The
    value is read from the whitespace-separated text file
    ``<trainer>/sgd.shared.training_begin.epoch.0.step.0/<model>/w.txt``,
    resolved relative to the current working directory. ``rank`` is
    accepted but not used.
    """
    # NOTE(review): the directory layout presumably matches what the
    # dump-weights callback writes for weights named 'w' -- confirm.
    trainer_name, model_name = experiment_path.split('/')
    weights_file = os.path.join(trainer_name,
                                'sgd.shared.training_begin.epoch.0.step.0',
                                model_name,
                                'w.txt')
    values = np.fromfile(weights_file, sep=' ')
    # item() requires that the file holds exactly one value
    return values.item() + 1.0
|
||
|
||
# ============================================== | ||
# Tests | ||
# ============================================== | ||
|
||
|
||
@test_util.lbann_test()
def test_metric():
    """Build a model tester that attaches a Python metric to a simple model.

    The model adds the constant 1 to a random 2x2 input. The metric calls
    ``evaluate`` from this module, and a check-metric callback requires
    the metric to equal exactly 1.0 in ``test`` mode.
    """
    # Reference data: random float32 input and the same values plus one
    np.random.seed(20240105)
    input_values = np.random.rand(2, 2).astype(np.float32)
    expected = input_values + 1

    tester = test_util.ModelTester()

    x = tester.inputs(input_values)
    reference = tester.make_reference(expected)

    # Layer under test
    y = lbann.AddConstant(x, constant=1)

    # Metric backed by the `evaluate` function in this module
    python_metric = lbann.PythonMetric(name='pymetric',
                                       module=module_name,
                                       module_dir=current_dir,
                                       function='evaluate')
    tester.extra_metrics.append(python_metric)

    # Callback that fails the run unless the metric is exactly 1.0
    metric_check = lbann.CallbackCheckMetric(metric='pymetric',
                                             lower_bound=1.0,
                                             upper_bound=1.0,
                                             error_on_failure=True,
                                             execution_modes='test')
    tester.extra_callbacks.append(metric_check)

    # Loss and tensor used for gradient checking
    tester.set_loss(lbann.MeanSquaredError(y, reference))
    tester.set_check_gradients_tensor(lbann.Square(y))
    return tester
|
||
|
||
@test_util.lbann_test()
def test_metric_with_callback():
    """Build a model tester whose Python metric reads dumped weights.

    The model adds a single-valued weight ``w`` to a random 2x2 input. A
    dump-weights callback is registered ahead of the check-metric
    callback, and the metric ``return_weight_value_plus_one`` is expected
    to equal ``w + 1`` within ``1e-8`` in ``test`` mode.
    """
    # Reference data: random float32 input plus one random weight value
    np.random.seed(20240104)
    input_values = np.random.rand(2, 2).astype(np.float32)
    w = np.random.rand(1).astype(np.float32)
    expected = input_values + w

    tester = test_util.ModelTester()

    x = tester.inputs(input_values)
    reference = tester.make_reference(expected)

    # Layers under test: a one-element weight tessellated to dims=[2]
    # and added to the input
    weight_obj = lbann.Weights(
        name='w',
        initializer=lbann.ValueInitializer(values=[w[0]]),
    )
    wlayer = lbann.WeightsLayer(weights=weight_obj, dims=[1])
    wbcast = lbann.Tessellate(wlayer, dims=[2])
    y = lbann.Add(x, wbcast)

    # Metric backed by `return_weight_value_plus_one` in this module
    python_metric = lbann.PythonMetric(name='pymetric',
                                       module=module_name,
                                       module_dir=current_dir,
                                       function='return_weight_value_plus_one')
    tester.extra_metrics.append(python_metric)

    # The dump-weights callback is listed before the metric check
    target = w[0] + 1
    dump_weights = lbann.CallbackDumpWeights(directory='.', epoch_interval=1)
    metric_check = lbann.CallbackCheckMetric(metric='pymetric',
                                             lower_bound=target - 1e-8,
                                             upper_bound=target + 1e-8,
                                             error_on_failure=True,
                                             execution_modes='test')
    tester.extra_callbacks.extend([dump_weights, metric_check])

    # Loss and tensor used for gradient checking
    tester.set_loss(lbann.MeanSquaredError(y, reference))
    tester.set_check_gradients_tensor(lbann.Square(y))
    return tester
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
//////////////////////////////////////////////////////////////////////////////// | ||
// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. | ||
// Produced at the Lawrence Livermore National Laboratory. | ||
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in | ||
// the CONTRIBUTORS file. <lbann-dev@llnl.gov> | ||
// | ||
// LLNL-CODE-697807. | ||
// All rights reserved. | ||
// | ||
// This file is part of LBANN: Livermore Big Artificial Neural Network | ||
// Toolkit. For details, see http://software.llnl.gov/LBANN or | ||
// https://github.com/LLNL/LBANN. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); you | ||
// may not use this file except in compliance with the License. You may | ||
// obtain a copy of the License at: | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||
// implied. See the License for the specific language governing | ||
// permissions and limitations under the license. | ||
//////////////////////////////////////////////////////////////////////////////// | ||
|
||
#ifndef LBANN_METRIC_EXECUTABLE_METRIC_HPP | ||
#define LBANN_METRIC_EXECUTABLE_METRIC_HPP | ||
|
||
#include "lbann/metrics/metric.hpp"

#include <string>
#include <utility>
#include <vector>
|
||
namespace lbann { | ||
|
||
/** @brief A metric that receives its value from parsing the output of a running | ||
* executable. | ||
* | ||
* This metric spawns an executable file with every evaluation and reads its | ||
* output to receive one value of type ``EvalType``. It expects the program | ||
* to have no other output to ``stdout``, and will be called with the following | ||
* command-line arguments: ``<filename> [other_args] <experiment directory>``. | ||
* Experiment directory is defined as the trainer name and model name, | ||
* separated by a slash (e.g., ``trainer0/model0``). The executable will be | ||
* run from the current working directory of the experiment folder itself. | ||
*/ | ||
class executable_metric : public metric | ||
{ | ||
|
||
public: | ||
executable_metric(lbann_comm* comm = nullptr, | ||
std::string name = "", | ||
std::string filename = "", | ||
std::string other_args = "") | ||
: metric(comm), m_name(name), m_filename(filename), m_other_args(other_args) | ||
{} | ||
executable_metric(const executable_metric& other) = default; | ||
executable_metric& operator=(const executable_metric& other) = default; | ||
virtual ~executable_metric() = default; | ||
executable_metric* copy() const override | ||
{ | ||
return new executable_metric(*this); | ||
} | ||
|
||
/** Return a string name for this metric. */ | ||
std::string name() const override; | ||
|
||
/** Archive for checkpoint and restart */ | ||
template <class Archive> | ||
void serialize(Archive& ar); | ||
|
||
/** Get list of pointers to layers. */ | ||
std::vector<ViewingLayerPtr> get_layer_pointers() const override; | ||
/** Set list of pointers to layers. */ | ||
void set_layer_pointers(std::vector<ViewingLayerPtr> layers) override; | ||
|
||
/** Save metric state to checkpoint. */ | ||
bool save_to_checkpoint_shared(persist& p) override; | ||
/** Load metric state from checkpoint. */ | ||
bool load_from_checkpoint_shared(persist& p) override; | ||
|
||
bool save_to_checkpoint_distributed(persist& p) override; | ||
bool load_from_checkpoint_distributed(persist& p) override; | ||
|
||
protected: | ||
void setup(model& m) override; | ||
EvalType evaluate(execution_mode mode, int mini_batch_size) override; | ||
|
||
private: | ||
/** Descriptive name for metric. */ | ||
std::string m_name; | ||
|
||
/** Path to executable to run. */ | ||
std::string m_filename; | ||
|
||
/** Arguments to prepend before experiment path. */ | ||
std::string m_other_args; | ||
|
||
/** Full command line to run. */ | ||
std::string m_cmd; | ||
}; | ||
|
||
} // namespace lbann | ||
|
||
#endif // LBANN_METRIC_EXECUTABLE_METRIC_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.