From d03deedc2e2d58c6844ccdf19e88345e1c3d4760 Mon Sep 17 00:00:00 2001 From: youssefmecky96 Date: Tue, 31 Oct 2023 15:54:16 +0100 Subject: [PATCH] added pytorch forecasting wrapper to make it easier --- configs/tasks/BinaryClassification.gin | 2 +- configs/tasks/Regression.gin | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/tasks/BinaryClassification.gin b/configs/tasks/BinaryClassification.gin index 492a12eb..f3a7790d 100644 --- a/configs/tasks/BinaryClassification.gin +++ b/configs/tasks/BinaryClassification.gin @@ -16,7 +16,7 @@ train_common.ram_cache = True # DEEP LEARNING DLPredictionWrapper.loss = @cross_entropy - +DLPredictionPytorchForecastingWrapper.loss= @cross_entropy # SELECTING PREPROCESSOR preprocess.preprocessor = @base_classification_preprocessor preprocess.vars = %vars diff --git a/configs/tasks/Regression.gin b/configs/tasks/Regression.gin index 5cf3f8d9..6b6ac607 100644 --- a/configs/tasks/Regression.gin +++ b/configs/tasks/Regression.gin @@ -16,6 +16,7 @@ train_common.ram_cache = True # LOSS FUNCTION DLPredictionWrapper.loss = @mse_loss +DLPredictionPytorchForecastingWrapper.loss = @mse_loss MLWrapper.loss = @mean_squared_error # SELECTING PREPROCESSOR