add FusedApplyRotaryEmbGradKernel #10517

Open · wants to merge 15 commits into base: master
144 changes: 144 additions & 0 deletions oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp
@@ -0,0 +1,144 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct FusedApplyRotaryEmbCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
std::string x_layout{};
std::string output_layout{};
std::string mode{};
int64_t tensor_index{};
int64_t k_size{};
float base = 0.0f;
int64_t rotary_size{};
};

class FusedApplyRotaryEmb : public OpExprGradFunction<FusedApplyRotaryEmbCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FusedApplyRotaryEmbCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FusedApplyRotaryEmbCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
AttrMap base_attrs_;
};

Maybe<void> FusedApplyRotaryEmb::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> FusedApplyRotaryEmb::Capture(FusedApplyRotaryEmbCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
CHECK_OR_RETURN((inputs.size() >= 1) && (inputs.size() <= 4))
<< Error::RuntimeError() << "the number of inputs of fused_apply_rotary_emb should be between 1 and 4";
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ComposedAttrMap composed_attrs(attrs, base_attrs_);
// inputs are (x), (x, position_ids), (x, cos, sin), or (x, cos, sin, position_ids);
// save them all so Apply can recover which optional inputs were present from the count.
for (size_t i = 0; i < inputs.size(); ++i) { ctx->SaveTensorForBackward(inputs.at(i)); }

ctx->x_layout = JUST(composed_attrs.GetAttr<std::string>("x_layout"));
ctx->output_layout = JUST(composed_attrs.GetAttr<std::string>("output_layout"));
ctx->mode = JUST(composed_attrs.GetAttr<std::string>("mode"));
ctx->tensor_index = JUST(composed_attrs.GetAttr<int64_t>("tensor_index"));
ctx->k_size = JUST(composed_attrs.GetAttr<int64_t>("k_size"));
ctx->base = JUST(composed_attrs.GetAttr<float>("base"));
ctx->rotary_size = JUST(composed_attrs.GetAttr<int64_t>("rotary_size"));

return Maybe<void>::Ok();
}

Maybe<void> FusedApplyRotaryEmb::Apply(const FusedApplyRotaryEmbCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1)
<< Error::RuntimeError() << "fused_apply_rotary_emb_grad expects exactly one output gradient";
const auto& saved_tensors = ctx->SavedTensors();

CHECK_OR_RETURN((saved_tensors.size() >= 1) && (saved_tensors.size() <= 4))
<< Error::RuntimeError() << "the number of saved tensors of fused_apply_rotary_emb_grad should be between 1 and 4";

if (ctx->requires_grad) {
if (saved_tensors.size() == 1) { // x
const auto& x = saved_tensors.at(0);
in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(
x, out_grads.at(0), NullOpt /*cos*/, NullOpt /*sin*/, NullOpt /*position_ids*/,
ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base,
ctx->rotary_size));
}

if (saved_tensors.size() == 2) { // x, position_ids
const auto& x = saved_tensors.at(0);
const auto& position_ids = saved_tensors.at(1);
in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(
x, out_grads.at(0), NullOpt /*cos*/, NullOpt /*sin*/, position_ids, ctx->x_layout,
ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base,
ctx->rotary_size));
}

if (saved_tensors.size() == 3) { // x, cos, sin
const auto& x = saved_tensors.at(0);
const auto& cos = saved_tensors.at(1);
const auto& sin = saved_tensors.at(2);

in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(
x, out_grads.at(0), cos, sin, NullOpt /*position_ids*/, ctx->x_layout, ctx->output_layout,
ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size));
}

if (saved_tensors.size() == 4) { // x, cos, sin, position_ids
const auto& x = saved_tensors.at(0);
const auto& cos = saved_tensors.at(1);
const auto& sin = saved_tensors.at(2);
const auto& position_ids = saved_tensors.at(3);
in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(
x, out_grads.at(0), cos, sin, position_ids, ctx->x_layout, ctx->output_layout, ctx->mode,
ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size));
}
}

return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_apply_rotary_emb", FusedApplyRotaryEmb);

} // namespace one
} // namespace oneflow
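For reviewers, a sketch of the math this backward implements. Rotary embedding rotates feature pairs by a position-dependent angle; since a rotation is orthogonal, the backward applies the transposed rotation to the output gradient (the exact pairing/indexing depends on the "interval" vs. "plane" mode):

$$
y_{2i} = x_{2i}\cos(m\theta_i) - x_{2i+1}\sin(m\theta_i), \qquad
y_{2i+1} = x_{2i}\sin(m\theta_i) + x_{2i+1}\cos(m\theta_i), \qquad
\theta_i = \mathrm{base}^{-2i/d_{\mathrm{rot}}}
$$

$$
\frac{\partial L}{\partial x_{2i}} = \cos(m\theta_i)\,\frac{\partial L}{\partial y_{2i}} + \sin(m\theta_i)\,\frac{\partial L}{\partial y_{2i+1}}, \qquad
\frac{\partial L}{\partial x_{2i+1}} = -\sin(m\theta_i)\,\frac{\partial L}{\partial y_{2i}} + \cos(m\theta_i)\,\frac{\partial L}{\partial y_{2i+1}}
$$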
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -1279,6 +1279,10 @@
signature: 'Tensor (Tensor x, *, Tensor cos=None, Tensor sin=None, Tensor position_ids=None, String x_layout="BHMK", String output_layout=None, String mode="plane", Int64 tensor_index=None, Int64 k_size=None, Float base=1e4, Int64 rotary_size=None) => FusedApplyRotaryEmb'
bind_python: True

- name: "fused_apply_rotary_emb_grad"
signature: 'Tensor (Tensor x, Tensor dy, Tensor cos=None, Tensor sin=None, Tensor position_ids=None, String x_layout="BHMK", String output_layout=None, String mode="plane", Int64 tensor_index=None, Int64 k_size=None, Float base=1e4, Int64 rotary_size=None) => FusedApplyRotaryEmbGrad'
bind_python: False

- name: "fused_relu_dropout_grad"
signature: "Tensor (Tensor dy, Tensor mask, Float scale) => FusedReluDropoutGrad"
bind_python: False
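Since the new entry is `bind_python: False`, the grad functor is only reachable through the autograd engine. A minimal smoke test (a sketch: assumes a CUDA build and that `bind_python: True` above exposes the forward as `flow._C.fused_apply_rotary_emb`):

```python
import oneflow as flow

# x in the default "BHMK" layout: batch=2, heads=8, seq=16, head_dim=64.
x = flow.randn(2, 8, 16, 64, device="cuda", requires_grad=True)
y = flow._C.fused_apply_rotary_emb(x)  # defaults: mode="plane", base=1e4
y.sum().backward()  # routed through the new FusedApplyRotaryEmbGrad functor
assert x.grad.shape == x.shape
```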
126 changes: 126 additions & 0 deletions oneflow/core/functional/impl/fused_attention_functor.cpp
@@ -733,6 +733,131 @@ class FusedApplyRotaryEmbFunctor {
std::shared_ptr<OpExpr> op_without_position_sinuous_;
};

class FusedApplyRotaryEmbGradFunctor {
public:
FusedApplyRotaryEmbGradFunctor() {
op_with_position_sinuous_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad")
.Input("x")
.Input("dy")
.Input("cos")
.Input("sin")
.Input("position_ids")
.Output("dx")
.Build());
op_with_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad")
.Input("x")
.Input("dy")
.Input("position_ids")
.Output("dx")
.Build());
op_without_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad")
.Input("x")
.Input("dy")
.Input("cos")
.Input("sin")
.Output("dx")
.Build());
op_without_position_sinuous_ = CHECK_JUST(
one::OpBuilder("fused_apply_rotary_emb_grad").Input("x").Input("dy").Output("dx").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& dy, const Optional<one::Tensor>& cos,
const Optional<one::Tensor>& sin,
const Optional<one::Tensor>& position_ids, const std::string& x_layout,
const Optional<std::string>& output_layout, const std::string& mode,
const Optional<int64_t>& tensor_index, const Optional<int64_t>& k_size,
const float base, const Optional<int64_t>& rotary_size) const {
int64_t b = 0, m = 0, h = 0, k = 0;

if (tensor_index) {
CHECK_OR_RETURN((JUST(tensor_index) >= 0) && (JUST(tensor_index) <= 2))
<< "tensor_index should be in the range [0, 2]";
}
CHECK_OR_RETURN((mode == "interval") || (mode == "plane"))
<< "mode should be \"interval\" or \"plane\"";

JUST(ParseDims("x", *x->shape(), x_layout, Optional<int64_t>(), k_size, &b, &m, &h, &k));

if (k_size) {
CHECK_EQ_OR_RETURN(JUST(k_size), k)
<< "k_size, if given, should be equal to K of cos, sin and x.";
}
if (rotary_size) {
CHECK_LE_OR_RETURN(JUST(rotary_size), k) << "rotary_size should be no more than k.";
}

int64_t rotary_emb_dim = 1;

if (position_ids) {
CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->NumAxes(), 3)
<< "ndims of position_ids should be 3, in the form of either B1M or B2M.";
CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(0), b)
<< "1st dim of position_ids should be equal to B.";
CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(2), m)
<< "3rd dim of position_ids should be equal to M.";
rotary_emb_dim = JUST(position_ids)->shape()->At(1);
CHECK_OR_RETURN(rotary_emb_dim == 1 || rotary_emb_dim == 2)
<< "2nd dim of position_ids should be 1 or 2.";
}

const int64_t actual_rotary_size = rotary_size.value_or(k) / rotary_emb_dim;
CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0)
<< "k, or rotary_size if given, should be a multiple of 2 * rotary_emb_dim.";

if (cos && sin) {
CHECK_EQ_OR_RETURN(JUST(cos)->shape()->NumAxes(), 2)
<< "The number of dimensions of cos should be equal to 2.";
CHECK_OR_RETURN(*JUST(cos)->shape() == *JUST(sin)->shape())
<< "cos & sin should have the same shape.";
CHECK_EQ_OR_RETURN(JUST(cos)->shape()->At(1), actual_rotary_size)
<< "The 2nd dimension of cos & sin should be equal to rotary_size / "
"rotary_embedding_dim.";
} else if (!cos && !sin) {
// do nothing
} else {
UNIMPLEMENTED_THEN_RETURN() << "cos & sin should both be given or not given.";
}

if (!position_ids) {
if (cos && sin) {
CHECK_GE_OR_RETURN(JUST(cos)->shape()->At(0), m)
<< "M of cos & sin should be no less than M of x when position_ids is not given.";
// K of cos & sin is checked inside ParseDims.
}
}

auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("x_layout", "output_layout", "mode",
"tensor_index", "k_size", "base", "rotary_size");
attrs.SetAllAttrs(x_layout, output_layout.value_or(x_layout), mode, tensor_index.value_or(0),
k_size.value_or(k), base, rotary_size.value_or(k));

if (position_ids) {
if (cos && sin) {
return OpInterpUtil::Dispatch<Tensor>(
*op_with_position_sinuous_, {x, dy, JUST(cos), JUST(sin), JUST(position_ids)}, attrs);
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_with_position_, {x, dy, JUST(position_ids)},
attrs);
}
} else {
if (cos && sin) {
return OpInterpUtil::Dispatch<Tensor>(*op_without_position_, {x, dy, JUST(cos), JUST(sin)},
attrs);
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_without_position_sinuous_, {x, dy}, attrs);
}
}
}

private:
std::shared_ptr<OpExpr> op_with_position_;
std::shared_ptr<OpExpr> op_with_position_sinuous_;
std::shared_ptr<OpExpr> op_without_position_;
std::shared_ptr<OpExpr> op_without_position_sinuous_;
};

} // namespace impl

@@ -741,6 +866,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
"FusedMultiHeadAttentionInferenceV2");
m.add_functor<impl::FusedAttentionConcatPastKeyValueFunctor>("FusedAttentionConcatPastKeyValue");
m.add_functor<impl::FusedApplyRotaryEmbFunctor>("FusedApplyRotaryEmb");
m.add_functor<impl::FusedApplyRotaryEmbGradFunctor>("FusedApplyRotaryEmbGrad");
}

} // namespace functional
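As a review aid, a NumPy reference for the backward under the interleaved-pair formulation sketched earlier (not the kernel's exact "plane"/"interval" memory layout); `cos`/`sin` are assumed pre-gathered per position, with shape broadcastable to the even-index half of `dy`:

```python
import numpy as np

def rope_grad_reference(dy: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    """dx = R(theta)^T @ dy, applied pairwise over the last dimension."""
    dy_even, dy_odd = dy[..., 0::2], dy[..., 1::2]
    dx = np.empty_like(dy)
    dx[..., 0::2] = dy_even * cos + dy_odd * sin   # transpose of the forward rotation
    dx[..., 1::2] = -dy_even * sin + dy_odd * cos
    return dx
```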
26 changes: 26 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -3979,6 +3979,32 @@ def OneFlow_FusedApplyRotaryEmbOp : OneFlow_BaseOp<"fused_apply_rotary_emb", [No
let has_data_type_infer_fn = 1;
}

def OneFlow_FusedApplyRotaryEmbGradOp : OneFlow_BaseOp<"fused_apply_rotary_emb_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
OneFlow_Tensor:$dy,
Optional<OneFlow_Tensor>:$cos,
Optional<OneFlow_Tensor>:$sin,
Optional<OneFlow_Tensor>:$position_ids
);
let output = (outs
OneFlow_Tensor:$dx
);
let attrs = (ins
DefaultValuedAttr<StrAttr, "\"BHMK\"">:$x_layout,
DefaultValuedAttr<StrAttr, "\"BHMK\"">:$output_layout,
DefaultValuedAttr<StrAttr, "\"plane\"">:$mode,
DefaultValuedAttr<SI64Attr, "0">:$tensor_index,
DefaultValuedAttr<F32Attr, "1e4">:$base,
DefaultValuedAttr<SI64Attr, "0">:$k_size,
DefaultValuedAttr<SI64Attr, "0">:$rotary_size
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}
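A note on the traits above: `AttrSizedOperandSegments` is required because `cos`, `sin`, and `position_ids` are `Optional`, so MLIR must record how many values fill each operand slot. The four legal combinations mirror the four `OpExpr` variants built in `FusedApplyRotaryEmbGradFunctor`; a sketch of the expected segment sizes for `(x, dy, cos, sin, position_ids)`, assuming MLIR's standard encoding:

```python
# Expected operand segment sizes per builder variant (illustrative only):
SEGMENTS = {
    "op_without_position_sinuous_": [1, 1, 0, 0, 0],  # x, dy
    "op_with_position_":            [1, 1, 0, 0, 1],  # x, dy, position_ids
    "op_without_position_":         [1, 1, 1, 1, 0],  # x, dy, cos, sin
    "op_with_position_sinuous_":    [1, 1, 1, 1, 1],  # x, dy, cos, sin, position_ids
}
```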

def OneFlow_EmbeddingGradOp : OneFlow_BaseOp<"embedding_grad", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,