Skip to content

Commit

Permalink
add raise error in connectivity
Browse files Browse the repository at this point in the history
  • Loading branch information
spirosChv committed Nov 29, 2024
1 parent 04af1b4 commit 3b149f9
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions codes_/receptive_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,14 +494,27 @@ def connectivity(inputs, outputs):
Returns
-------
numpy.ndarray
connectivity_matrix : numpy.ndarray [int]
The connectivity matrix between inputs and outputs.
Raises
------
ValueError
If inputs or outputs are non-positive, or if inputs are not divisible by outputs.
"""
mask = np.zeros((inputs, outputs))
if outputs <= 0:
raise ValueError("Number of outputs must be greater than zero.")
if inputs <= 0:
raise ValueError("Number of inputs must be greater than zero.")
if inputs % outputs != 0:
raise ValueError("Inputs must be divisible by outputs without a remainder.")

connectivity_matrix = np.zeros((inputs, outputs), dtype=int)
in_per_out = inputs // outputs # nodes per node
# Fill the connectivity matrix
for j in range(outputs):
mask[in_per_out * j:in_per_out * (j + 1), j] = 1

return (mask.astype('int'))
start_index = in_per_out * j
end_index = start_index + in_per_out
connectivity_matrix[start_index:end_index, j] = 1
return connectivity_matrix

0 comments on commit 3b149f9

Please sign in to comment.