-
Notifications
You must be signed in to change notification settings - Fork 2
/
CDSC.py
36 lines (27 loc) · 1.12 KB
/
CDSC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""
@FileName: CDSC.py
@Author: Chenghong Xiao
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import ComplexConv1d
class CDSC1d(nn.Module):
    """Complex-Valued Depthwise Separable Convolution (CDSC), 1-D variant.

    A CDSC factorizes one regular convolution into two stages:

    1. a depthwise convolution (DWC) along the spatial axis, implemented as a
       real-valued ``nn.Conv1d`` with ``groups=in_channels`` so every input
       channel is filtered by its own kernel;
    2. a pointwise convolution (PWC) along the channel axis, implemented as a
       complex-valued ``ComplexConv1d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False):
        """Build the depthwise and pointwise stages.

        Args:
            in_channels: Number of input channels. The depthwise stage also
                outputs this many channels.
            out_channels: Number of channels produced by the pointwise stage.
            kernel_size: Kernel size of the depthwise stage.
            stride: Stride of the depthwise stage.
            padding: Padding of the depthwise stage.
            dilation: Dilation of the depthwise stage.
            bias: Whether both stages use a bias term.
        """
        super().__init__()
        # Depthwise stage: groups == in_channels gives one independent filter
        # per channel, so no information is mixed across channels here.
        self.DWC = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # Pointwise stage: the positional arguments are presumably
        # (kernel_size=1, stride=1, padding=0, dilation=1, groups=1) in
        # nn.Conv1d order. NOTE(review): confirm against utils.ComplexConv1d.
        self.PWC = ComplexConv1d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        """Run the depthwise stage, then feed its result to the pointwise stage.

        Args:
            x: Input tensor with ``in_channels`` channels, as ``nn.Conv1d``
                requires. NOTE(review): how complex values are laid out
                (e.g. stacked real/imaginary parts) is defined by
                ``utils.ComplexConv1d``; confirm against that class.

        Returns:
            Whatever the pointwise ``ComplexConv1d`` stage returns.
        """
        return self.PWC(self.DWC(x))