Skip to content

Commit 8d7ba4a

Browse files
LukeBoyertensorflower-gardener
authored andcommitted
Add prepare pattern to rewrite TF_Select to TF_SelectV2 in the case of the inputs not all being static and equal.
TF_Select does not support broadcasting whereas TF_SelectV2 does. Broadcasting might be needed at runtime if stapes are not fully known ahead of time so we need TF_SelectV2. See tensorflow/ir/tf_ops.td for more op details. PiperOrigin-RevId: 534142327
1 parent a22b2f2 commit 8d7ba4a

File tree

3 files changed

+62
-14
lines changed

3 files changed

+62
-14
lines changed

tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir

+24
Original file line numberDiff line numberDiff line change
@@ -737,4 +737,28 @@ func.func @UnsupportedGroupConv_DynamicDimAtInputDimThree(%arg0: tensor<?x1x26x?
737737
// CHECK: "tf.Conv2D"
738738
}
739739

740+
func.func @Select_SameStaticShapeUnchanged(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> (tensor<2xf32>) {
741+
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
742+
func.return %0 : tensor<2xf32>
743+
// CHECK-LABEL: Select_SameStaticShapeUnchanged
744+
// CHECK-NOT: "tf.SelectV2"
745+
// CHECK: "tf.Select"
746+
}
747+
748+
func.func @Select_SameStaticShapeUnchangedWithBroadcastedCond(%arg0: tensor<2xi1>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> (tensor<2x2xf32>) {
749+
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
750+
func.return %0 : tensor<2x2xf32>
751+
// CHECK-LABEL: SameStaticShapeUnchangedWithBroadcastedCond
752+
// CHECK-NOT: "tf.SelectV2"
753+
// CHECK: "tf.Select"
754+
}
755+
756+
func.func @Select_NotSameStaticShapeRewritesToSelectV2(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> (tensor<*xf32>) {
757+
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
758+
func.return %0 : tensor<*xf32>
759+
// CHECK-LABEL: Select_NotSameStaticShapeRewritesToSelectV2
760+
// CHECK-NOT: "tf.Select"
761+
// CHECK: "tf.SelectV2"
762+
}
763+
740764
}

tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td

+17
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
include "tensorflow/compiler/mlir/tensorflow/transforms/optimize.td"
1717
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
18+
include "tensorflow/compiler/mlir/lite/utils/utils.td"
1819

1920
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
2021

@@ -182,6 +183,22 @@ def LowerTensorScatterAdd: Pat<
182183
(TF_TensorScatterUpdateOp $input, $indices,
183184
(TF_AddOp $updates, (CreateGatherNdOp $updates, $input, $indices)))>;
184185

186+
//===----------------------------------------------------------------------===//
187+
// Select op patterns.
188+
//===----------------------------------------------------------------------===//
189+
190+
def NotOperandOneAndTwoHaveSameStaticShape : Constraint<Neg<CPred<"OpHasSameStaticShapes($0.getDefiningOp(), ArrayRef({1, 2}))">>, "">;
191+
192+
// If a TF_Select op does not have same static shape for it's value operands, its arguments may
193+
// need to be broadcasted at runtime. Only V2 supports this so we rewrite
194+
// before legalization.
195+
// TODO(b/270755115) This is a temporary workaround, investigate if this
196+
// pattern is needed in all situations.
197+
def ConvertNonStaticSelectToSelectV2 : Pat<
198+
(TF_SelectOp:$src $cond, $x, $y),
199+
(TF_SelectV2Op $cond, $x, $y),
200+
[(NotOperandOneAndTwoHaveSameStaticShape $src)]>;
201+
185202
//===----------------------------------------------------------------------===//
186203
// AddV2 op patterns.
187204
//===----------------------------------------------------------------------===//

tensorflow/compiler/mlir/lite/utils/utils.h

+21-14
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License.
1616
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_
1717
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_
1818

19+
#include <numeric>
20+
1921
#include "llvm/ADT/ArrayRef.h"
2022
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
2123
#include "mlir/IR/Operation.h" // from @llvm-project
@@ -28,27 +30,32 @@ using mlir::Operation;
2830
using mlir::ShapedType;
2931
using mlir::Value;
3032

31-
// Returns true if all tensor value in `values` has static shape and same shape.
32-
inline bool OpHasSameStaticShapes(Operation* op) {
33-
auto values = op->getOperands();
34-
int operand_num = 0;
35-
ArrayRef<int64_t> shape;
36-
for (Value value : values) {
37-
auto shaped_type = value.getType().dyn_cast<ShapedType>();
33+
// Returns true if each Operand at given indices have the same static shape.
34+
inline bool OpHasSameStaticShapes(Operation* op,
35+
llvm::ArrayRef<int> operand_idxs) {
36+
if (op->getNumOperands() == 0 || operand_idxs.empty()) return true;
37+
const int first_opr_idx = operand_idxs[0];
38+
ArrayRef<int64_t> shape =
39+
op->getOperand(first_opr_idx).getType().dyn_cast<ShapedType>().getShape();
40+
for (int opr_idx : operand_idxs) {
41+
Value operand = op->getOperand(opr_idx);
42+
auto shaped_type = operand.getType().dyn_cast<ShapedType>();
3843
if (!shaped_type || !shaped_type.hasStaticShape()) {
3944
return false;
4045
}
41-
if (operand_num == 0) {
42-
shape = shaped_type.getShape();
43-
} else {
44-
if (shape != shaped_type.getShape()) {
45-
return false;
46-
}
46+
if (shape != shaped_type.getShape()) {
47+
return false;
4748
}
48-
++operand_num;
4949
}
5050
return true;
5151
}
52+
53+
// Returns true if each Operand has the same static shape.
54+
inline bool OpHasSameStaticShapes(Operation* op) {
55+
llvm::OwningArrayRef<int> operand_idxs(op->getNumOperands());
56+
std::iota(operand_idxs.begin(), operand_idxs.end(), 0);
57+
return OpHasSameStaticShapes(op, operand_idxs);
58+
}
5259
} // namespace TFL
5360
} // namespace mlir
5461

0 commit comments

Comments
 (0)