Skip to content

Commit

Permalink
Remove self.device to prevent import requirement of torch for predict…
Browse files Browse the repository at this point in the history
…ion (#617)
  • Loading branch information
darrylong authored May 6, 2024
1 parent ae7ba86 commit 5de0a8e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
3 changes: 1 addition & 2 deletions cornac/models/lightgcn/lightgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ def forward(self, g, feat_dict):


class Model(nn.Module):
def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
def __init__(self, g, in_size, num_layers, lambda_reg):
super(Model, self).__init__()
self.norm_dict = dict()
self.lambda_reg = lambda_reg
self.device = device

self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)])

Expand Down
6 changes: 3 additions & 3 deletions cornac/models/lightgcn/recom_lightgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,21 +124,21 @@ def fit(self, train_set, val_set=None):
from .lightgcn import Model
from .lightgcn import construct_graph

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if self.seed is not None:
torch.manual_seed(self.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(self.seed)

graph = construct_graph(train_set, self.total_users, self.total_items).to(
self.device
device
)
model = Model(
graph,
self.emb_size,
self.num_layers,
self.lambda_reg,
).to(self.device)
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)

Expand Down

0 comments on commit 5de0a8e

Please sign in to comment.