Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev mv code from modules to functional #10420

Open
wants to merge 52 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
fa04b47
add jvp interface and test
lihuizhao Jan 2, 2024
57ec099
add jvp interface and test
lihuizhao Jan 2, 2024
256d75a
add annotation information
lihuizhao Jan 2, 2024
cc7b43d
add jacobian
lihuizhao Jan 4, 2024
4f4c1d5
add vhp
lihuizhao Jan 4, 2024
49dfb97
add vhp test
lihuizhao Jan 4, 2024
f26d539
add hvp
lihuizhao Jan 4, 2024
65d90ce
add hvp test and fix vhp test
lihuizhao Jan 4, 2024
c6b8298
fix jacobian interface when the condition is strategy=reverse-mode an…
lihuizhao Jan 4, 2024
b5bac0e
fix jacobian interface when the condition is strategy=reverse-mode an…
lihuizhao Jan 5, 2024
5906339
add jacobian tests and hessian tests
lihuizhao Jan 10, 2024
e69a37c
format code
lihuizhao Jan 10, 2024
4bbb368
code format
lihuizhao Jan 10, 2024
030dc61
change jvp doc
lihuizhao Jan 10, 2024
6b8db1f
add jvp/jacobian/hessian/vhp/hvp to autograd.rst
lihuizhao Jan 10, 2024
233d423
add annotation
lihuizhao Jan 10, 2024
f54c58b
remove _batched_autograd_grad() function
lihuizhao Jan 12, 2024
81fe941
format
lihuizhao Jan 15, 2024
6efac49
Merge branch 'master' into dev_add_jvp_jacobian_hessian
lihuizhao Jan 15, 2024
6ab73d6
fix
lihuizhao Jan 16, 2024
ab874fb
Remove extra code
lihuizhao Jan 16, 2024
3e3b457
move interpolate function code from modules to functional
lihuizhao Jan 24, 2024
c3d03ac
move affine_grid code from modules to functional
lihuizhao Jan 24, 2024
c14a50b
Modify comment
lihuizhao Jan 24, 2024
f069447
format
lihuizhao Jan 24, 2024
325b526
move grid_sample code from modules to functional
lihuizhao Jan 24, 2024
46281cf
move sparse_softmax_cross_entropy code from modules to functional
lihuizhao Jan 24, 2024
28e58a4
move layer_norm code from modules to functional
lihuizhao Jan 24, 2024
7583963
move embedding code from modules to functional
lihuizhao Jan 24, 2024
eb91311
move linear code from modules to functional
lihuizhao Jan 24, 2024
74abde3
move relu6 code from modules to functional
lihuizhao Jan 24, 2024
fbb4164
move upsample code from modules to functional
lihuizhao Jan 24, 2024
3b1c519
pull code from github
lihuizhao Jan 25, 2024
3e51963
pull code from github
lihuizhao Jan 25, 2024
831b1f7
move group_norm code from modules to functional
lihuizhao Jan 25, 2024
a4ec283
format
lihuizhao Jan 25, 2024
ef1ee32
Remove useless code
lihuizhao Jan 25, 2024
ea51cc8
Modify comment document
lihuizhao Jan 25, 2024
377c0e9
Modify comment document
lihuizhao Jan 25, 2024
f2490c9
Modify comment document
lihuizhao Jan 25, 2024
0c109d9
merge upsample and interpolate
lihuizhao Jan 26, 2024
38e7b31
Compact code
lihuizhao Jan 26, 2024
6f0e8a9
add upsample func test code
lihuizhao Jan 29, 2024
c40ec4e
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
lihuizhao Jan 29, 2024
56f9fdc
add interpolate module test code
lihuizhao Jan 29, 2024
e13b1dd
add AffineGrid class test
lihuizhao Jan 29, 2024
ba92dce
improve LayerNorm class code
lihuizhao Jan 29, 2024
07a32fc
add LayerNorm class test code
lihuizhao Jan 30, 2024
6368658
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
lihuizhao Jan 30, 2024
a9ea436
Merge branch 'master' into dev_mv_code_from_modules_to_functional
marigoold Feb 2, 2024
406e46a
Merge remote-tracking branch 'upstram/dev_mv_code_from_modules_to_fun…
lihuizhao Feb 5, 2024
bd3c34d
Add 'Type hints', 'documentation', 'example code'
lihuizhao Feb 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/oneflow/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@
GRU,
)

from oneflow.nn.modules.interpolate import Interpolate
from oneflow.nn.modules.affine_grid import AffineGrid
from oneflow.nn.modules.grid_sample import GridSample
from oneflow.nn.modules.sparse_softmax_cross_entropy import SparseSoftmaxCrossEntropy

from oneflow.nn.qat.conv import QatConv1d, QatConv2d, QatConv3d


Expand Down
17 changes: 8 additions & 9 deletions python/oneflow/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from oneflow.nn.modules.interpolate import interpolate
from oneflow.nn.modules.affine_grid import affine_grid
from oneflow.nn.modules.grid_sample import grid_sample
from oneflow.nn.modules.sparse_softmax_cross_entropy import sparse_softmax_cross_entropy
from .interpolate import interpolate, upsample
from .affine_grid import affine_grid
from .grid_sample import grid_sample
from .sparse_softmax_cross_entropy import sparse_softmax_cross_entropy
from oneflow._C import conv1d
from oneflow._C import conv2d
from oneflow._C import conv3d
Expand Down Expand Up @@ -65,7 +65,7 @@
from oneflow._C import threshold
from oneflow._C import silu
from oneflow._C import mish
from oneflow.nn.modules.normalization import layer_norm, group_norm
from .normalization import layer_norm, group_norm
from oneflow._C import dropout, dropout1d, dropout2d, dropout3d
from oneflow._C import smooth_l1_loss
from .pad import pad
Expand All @@ -82,10 +82,9 @@
from oneflow._C import (
binary_cross_entropy_with_logits_loss as binary_cross_entropy_with_logits,
)
from oneflow.nn.modules.sparse import embedding
from oneflow.nn.modules.linear import linear
from oneflow.nn.modules.activation import relu6
from oneflow.nn.modules.upsampling import Upsample as upsample
from .sparse import embedding
from .linear import linear
from .activation import relu6
from oneflow._C import unfold
from oneflow._C import fold
from .deform_conv import deform_conv2d
Expand Down
36 changes: 36 additions & 0 deletions python/oneflow/nn/functional/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import warnings
import oneflow as flow
from oneflow.framework.tensor import Tensor


def relu6(input: Tensor, inplace=False) -> Tensor:
r"""relu6(input: Tensor, inplace=False) -> Tensor

Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`.

See :class:`~oneflow.nn.ReLU6` for more details.
"""
if inplace:
warnings.warn("relu6 do not support inplace now")
return flow._C.hardtanh(input, min_val=0.0, max_val=6.0)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
75 changes: 75 additions & 0 deletions python/oneflow/nn/functional/affine_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List

import oneflow as flow
from oneflow.framework.tensor import Tensor


def affine_grid(theta: Tensor, size: List[int], align_corners: bool = False) -> Tensor:
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/1.10/generated/torch.nn.functional.affine_grid.html.

Generates a 2D or 3D flow field (sampling grid), given a batch of
affine matrices :attr:`theta`.

.. note::
This function is often used in conjunction with :func:`grid_sample`
to build `Spatial Transformer Networks`_ .

Args:
theta (Tensor): input batch of affine matrices with shape
(:math:`N, 2, 3`) for 2D or
(:math:`N, 3, 4`) for 3D
size (oneflow.Size): the target output image size.
(:math:`N, C, H, W` for 2D or
:math:`N, C, D, H, W` for 3D)
Example: oneflow.Size((32, 3, 24, 24))
align_corners (bool): if ``True``, consider ``-1`` and ``1``
to refer to the centers of the corner pixels rather than the image corners.
Refer to :func:`grid_sample` for a more complete description.
A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample`
with the same setting for this option.
Default: ``False``

Returns:
output (Tensor): output Tensor of size (:math:`N, H, W, 2`)

.. _`Spatial Transformer Networks`:
https://arxiv.org/abs/1506.02025

Examples::

>>> import oneflow as flow
>>> import numpy as np
>>> input = flow.tensor(np.arange(1., 7).reshape((1, 2, 3)), dtype=flow.float32)
>>> output = flow.nn.functional.affine_grid(input, flow.Size([1, 1, 2, 2]), align_corners=True)
>>> output
tensor([[[[ 0., -3.],
[ 2., 5.]],
<BLANKLINE>
[[ 4., 7.],
[ 6., 15.]]]], dtype=oneflow.float32)
"""
y = flow._C.affine_grid(theta, size=size, align_corners=align_corners)
return y


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
146 changes: 146 additions & 0 deletions python/oneflow/nn/functional/grid_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import oneflow as flow
from oneflow.framework.tensor import Tensor


def grid_sample(
input: Tensor,
grid: Tensor,
mode: str = "bilinear",
padding_mode: str = "zeros",
align_corners: bool = False,
) -> Tensor:
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/1.10/generated/torch.nn.functional.grid_sample.html.

Given an :attr:`input` and a flow-field :attr:`grid`, computes the
``output`` using :attr:`input` values and pixel locations from :attr:`grid`.

Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are
supported.

In the spatial (4-D) case, for :attr:`input` with shape
:math:`(N, C, H_{in}, W_{in})` and :attr:`grid` with shape
:math:`(N, H_{out}, W_{out}, 2)`, the output will have shape
:math:`(N, C, H_{out}, W_{out})`.

For each output location ``output[n, :, h, w]``, the size-2 vector
``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``,
which are used to interpolate the output value ``output[n, :, h, w]``.
In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the
``x``, ``y``, ``z`` pixel locations for interpolating
``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or
``bilinear`` interpolation method to sample the input pixels.

:attr:`grid` specifies the sampling pixel locations normalized by the
:attr:`input` spatial dimensions. Therefore, it should have most values in
the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` is the
left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` is the
right-bottom pixel of :attr:`input`.

If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding
outputs are handled as defined by :attr:`padding_mode`. Options are

* ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations,
* ``padding_mode="border"``: use border values for out-of-bound grid locations,
* ``padding_mode="reflection"``: use values at locations reflected by
the border for out-of-bound grid locations. For location far away
from the border, it will keep being reflected until becoming in bound,
e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1``
and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes
``x'' = -0.5``.

Note:
This function is often used in conjunction with :func:`affine_grid`
to build `Spatial Transformer Networks`_ .

Note:
NaN values in :attr:`grid` would be interpreted as ``-1``.

Args:
input (Tensor): input of shape :math:`(N, C, H_{in}, W_{in})` (4-D case)
or :math:`(N, C, D_{in}, H_{in}, W_{in})` (5-D case)
grid (Tensor): flow-field of shape :math:`(N, H_{out}, W_{out}, 2)` (4-D case)
or :math:`(N, D_{out}, H_{out}, W_{out}, 3)` (5-D case)
mode (str): interpolation mode to calculate output values
``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'``
Note: ``mode='bicubic'`` supports only 4-D input.
When ``mode='bilinear'`` and the input is 5-D, the interpolation mode
used internally will actually be trilinear. However, when the input is 4-D,
the interpolation mode will legitimately be bilinear.
padding_mode (str): padding mode for outside grid values
``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'``
align_corners (bool): Geometrically, we consider the pixels of the
input as squares rather than points.
If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring
to the center points of the input's corner pixels. If set to ``False``, they
are instead considered as referring to the corner points of the input's corner
pixels, making the sampling more resolution agnostic.
This option parallels the ``align_corners`` option in
:func:`interpolate`, and so whichever option is used here
should also be used there to resize the input image before grid sampling.
Default: ``False``

Returns:
output (Tensor): output Tensor

.. _`Spatial Transformer Networks`:
https://arxiv.org/abs/1506.02025

.. note::
``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\\alpha=-0.75`.
The constant :math:`\\alpha` might be different from packages to packages.
For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively.
This algorithm may "overshoot" the range of values it's interpolating.
For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255].
Clamp the results with :func: `flow.clamp` to ensure they are within the valid range.
.. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation
.. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51
.. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908

Examples::

>>> import oneflow as flow
>>> import numpy as np
>>> input = flow.tensor(np.arange(1., 11).reshape((1, 1, 2, 5)), dtype=flow.float32)
>>> np_grid = np.array(
... [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
... [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]]
... ).reshape(1, 2, 5, 2)
>>> grid = flow.tensor(np_grid, dtype=flow.float32)
>>> output = flow.nn.functional.grid_sample(input, grid, mode='nearest', padding_mode='zeros',
... align_corners=True)
>>> output
tensor([[[[0., 8., 5., 7., 9.],
[1., 8., 5., 8., 0.]]]], dtype=oneflow.float32)
"""
y = flow._C.grid_sample(
input,
grid,
interpolation_mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)
return y


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
Loading
Loading