
Error while implementing SubgraphX #253

Open
xristos19353 opened this issue Oct 28, 2024 · 0 comments


@xristos19353

Hi all,

I installed DIG and tried to explain a graph classification model that I built with PyG:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, 1)  # Single logit for binary classification

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)

        if batch is not None:
            # Global mean pooling to aggregate node features per graph
            x = global_mean_pool(x, batch)
        else:
            # No batch vector: treat the input as a single graph and average all nodes
            x = x.mean(dim=0, keepdim=True)
        x = self.fc(x)  # Final classification layer
        return x.squeeze()  # Raw logits (apply sigmoid to get probabilities)


model = GCN(in_channels=10, hidden_channels=16)  # No out_channels needed: the output is a single logit
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.BCEWithLogitsLoss()  # BCE on logits for binary classification
```
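The model above returns a single logit and is trained with `BCEWithLogitsLoss`, so its class-1 probability is `sigmoid(logit)`. If an explainer instead treats the output as per-class logits and applies a softmax over the last dimension (an assumption here; how SubgraphX's `gnn_score` reward converts outputs into scores should be checked in the DIG source), a one-column output always normalizes to 1.0 and carries no signal. The minimal sketch below uses a hypothetical helper, `to_two_class_logits`, to show that placing a zero logit next to the existing one yields a two-class output whose class-1 softmax probability equals the sigmoid, so no retraining would be needed:

```python
import torch


def to_two_class_logits(logit: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not part of DIG or PyG): map binary logits z to [0, z].

    softmax([0, z])[1] = e^z / (1 + e^z) = sigmoid(z), so class 1 keeps the
    probability the BCE-trained model already assigns.
    """
    z = logit.reshape(-1, 1)  # shape [num_graphs, 1]
    return torch.cat([torch.zeros_like(z), z], dim=-1)  # shape [num_graphs, 2]


z = torch.tensor([-2.0, 0.0, 3.5])  # example logits for three graphs

# Assumption being illustrated: an explainer that applies softmax over the last dim.
# A softmax over a single column is always 1.0, whatever the logit is.
single = torch.softmax(z.reshape(-1, 1), dim=-1)
print(single.squeeze(-1))  # tensor([1., 1., 1.])

# With two columns, the class-1 probability matches sigmoid(z).
two = torch.softmax(to_two_class_logits(z), dim=-1)
print(torch.allclose(two[:, 1], torch.sigmoid(z)))  # True
```

With a two-column output, `num_classes=2` (rather than 1) would presumably be the matching SubgraphX setting; that is an assumption, not something confirmed in this issue.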

When I run the SubgraphX explainer with the following code:

```python
from dig.xgraph.method import SubgraphX

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
explainer = SubgraphX(model, num_classes=1, device=device, explain_graph=True,
                      reward_method='gnn_score')
explainer(val_data[0].x, val_data[0].edge_index)
```

I get the error `TypeError: GCN.forward() got an unexpected keyword argument 'data'` while the scores are being computed.
Do you have any idea how to resolve this error?
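The error message indicates that, when computing scores, SubgraphX calls the model as `model(data=...)`, passing a whole graph object by keyword, while `GCN.forward` only accepts `x`, `edge_index` and `batch`. One possible workaround is a thin wrapper that accepts either calling convention. The sketch below is a hypothetical adapter (`DataKwargAdapter` is not part of DIG or PyG) and assumes the object passed as `data` exposes `x`, `edge_index` and, optionally, `batch` attributes, as PyG `Data`/`Batch` objects do. It is demonstrated on a stand-in model and has not been tested against DIG itself:

```python
import types

import torch


class DataKwargAdapter(torch.nn.Module):
    """Hypothetical adapter: lets a model with forward(x, edge_index, batch=None)
    also be called as model(data=graph). Assumes graph has .x, .edge_index and
    optionally .batch, as PyG Data/Batch objects do."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x=None, edge_index=None, batch=None, data=None):
        if data is not None:
            # Unpack the graph object passed by keyword; batch may be absent.
            x, edge_index = data.x, data.edge_index
            batch = getattr(data, "batch", None)
        if x is None or edge_index is None:
            raise ValueError("pass either (x, edge_index) or data=")
        return self.model(x, edge_index, batch=batch)


class MeanModel(torch.nn.Module):
    """Stand-in for the GCN above: averages node features per graph."""

    def forward(self, x, edge_index, batch=None):
        if batch is None:
            return x.mean(dim=0, keepdim=True)
        num_graphs = int(batch.max()) + 1
        return torch.stack([x[batch == g].mean(dim=0) for g in range(num_graphs)])


wrapped = DataKwargAdapter(MeanModel())
x = torch.tensor([[1.0], [3.0], [5.0]])
edge_index = torch.tensor([[0, 1], [1, 2]])
graph = types.SimpleNamespace(x=x, edge_index=edge_index)  # stand-in for a PyG Data object

out_positional = wrapped(x, edge_index)  # the convention used during training
out_keyword = wrapped(data=graph)        # the convention in the error message
print(torch.equal(out_positional, out_keyword))  # True
```

With the real model this would be passed as `SubgraphX(DataKwargAdapter(model), ...)`; whether other parts of SubgraphX call the model with further keyword arguments is not verified here.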
