From 5de0a8e9bf363827508c0b98a9f6f1590350357b Mon Sep 17 00:00:00 2001 From: darrylong Date: Tue, 7 May 2024 06:04:07 +0800 Subject: [PATCH] Remove self.device to prevent import requirement of torch for prediction (#617) --- cornac/models/lightgcn/lightgcn.py | 3 +-- cornac/models/lightgcn/recom_lightgcn.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cornac/models/lightgcn/lightgcn.py b/cornac/models/lightgcn/lightgcn.py index fbe37a460..a88342ac0 100644 --- a/cornac/models/lightgcn/lightgcn.py +++ b/cornac/models/lightgcn/lightgcn.py @@ -62,11 +62,10 @@ def forward(self, g, feat_dict): class Model(nn.Module): - def __init__(self, g, in_size, num_layers, lambda_reg, device=None): + def __init__(self, g, in_size, num_layers, lambda_reg): super(Model, self).__init__() self.norm_dict = dict() self.lambda_reg = lambda_reg - self.device = device self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)]) diff --git a/cornac/models/lightgcn/recom_lightgcn.py b/cornac/models/lightgcn/recom_lightgcn.py index e4967652e..eaeb650a9 100644 --- a/cornac/models/lightgcn/recom_lightgcn.py +++ b/cornac/models/lightgcn/recom_lightgcn.py @@ -124,21 +124,21 @@ def fit(self, train_set, val_set=None): from .lightgcn import Model from .lightgcn import construct_graph - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if self.seed is not None: torch.manual_seed(self.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(self.seed) graph = construct_graph(train_set, self.total_users, self.total_items).to( - self.device + device ) model = Model( graph, self.emb_size, self.num_layers, self.lambda_reg, - ).to(self.device) + ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)