
Commit e015b4d

suxinguo authored and tensorflower-gardener committed
Add prepare pattern to rewrite TF_Select to TF_SelectV2 in the case of the input shapes not all being static and equal.

TF_Select does not support broadcasting, whereas TF_SelectV2 does. Broadcasting might be needed at runtime if shapes are not fully known ahead of time, so we need TF_SelectV2. See tensorflow/ir/tf_ops.td for more op details.

PiperOrigin-RevId: 534260994
1 parent 02e8cd4 commit e015b4d

File tree

3 files changed: +14 -62 lines changed
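
For context, below is a minimal MLIR sketch of the rewrite described in the commit message, adapted from the test cases that appear in the prepare-tf.mlir diff below (the function name is illustrative). When the operand shapes are not all known to be static and equal, as with the unranked tensors here, tf.Select is turned into tf.SelectV2, which supports runtime broadcasting.

// Input to the prepare pass: unranked operands, so equal static shapes cannot
// be proven.
func.func @select_with_unranked_operands(%cond: tensor<*xi1>, %x: tensor<*xf32>, %y: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.Select"(%cond, %x, %y) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// Expected result of the rewrite: the same computation expressed with
// tf.SelectV2, which can broadcast its operands at runtime.
func.func @select_with_unranked_operands(%cond: tensor<*xi1>, %x: tensor<*xf32>, %y: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.SelectV2"(%cond, %x, %y) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

Select ops whose value operands already share a known static shape are left as tf.Select; only the non-static case needs the broadcasting-capable op.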

tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir

Lines changed: 0 additions & 24 deletions
@@ -737,30 +737,6 @@ func.func @UnsupportedGroupConv_DynamicDimAtInputDimThree(%arg0: tensor<?x1x26x?
   // CHECK: "tf.Conv2D"
 }
 
-func.func @Select_SameStaticShapeUnchanged(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> (tensor<2xf32>) {
-  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
-  func.return %0 : tensor<2xf32>
-  // CHECK-LABEL: Select_SameStaticShapeUnchanged
-  // CHECK-NOT: "tf.SelectV2"
-  // CHECK: "tf.Select"
-}
-
-func.func @Select_SameStaticShapeUnchangedWithBroadcastedCond(%arg0: tensor<2xi1>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> (tensor<2x2xf32>) {
-  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  func.return %0 : tensor<2x2xf32>
-  // CHECK-LABEL: SameStaticShapeUnchangedWithBroadcastedCond
-  // CHECK-NOT: "tf.SelectV2"
-  // CHECK: "tf.Select"
-}
-
-func.func @Select_NotSameStaticShapeRewritesToSelectV2(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> (tensor<*xf32>) {
-  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-  func.return %0 : tensor<*xf32>
-  // CHECK-LABEL: Select_NotSameStaticShapeRewritesToSelectV2
-  // CHECK-NOT: "tf.Select"
-  // CHECK: "tf.SelectV2"
-}
-
 func.func @RedundantShapeOp(%shape: tensor<?xi64>, %fill: tensor<f32>) -> (tensor<?xi64>) {
   %0 = "tf.Fill"(%shape, %fill) : (tensor<?xi64>, tensor<f32>) -> (tensor<*xf32>)
   %1 = "tf.Shape"(%0) : (tensor<*xf32>) -> (tensor<?xi64>)

tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td

Lines changed: 0 additions & 17 deletions
@@ -15,7 +15,6 @@ limitations under the License.
 
 include "tensorflow/compiler/mlir/tensorflow/transforms/optimize.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
-include "tensorflow/compiler/mlir/lite/utils/utils.td"
 
 def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
 
@@ -183,22 +182,6 @@ def LowerTensorScatterAdd: Pat<
   (TF_TensorScatterUpdateOp $input, $indices,
     (TF_AddOp $updates, (CreateGatherNdOp $updates, $input, $indices)))>;
 
-//===----------------------------------------------------------------------===//
-// Select op patterns.
-//===----------------------------------------------------------------------===//
-
-def NotOperandOneAndTwoHaveSameStaticShape : Constraint<Neg<CPred<"OpHasSameStaticShapes($0.getDefiningOp(), ArrayRef({1, 2}))">>, "">;
-
-// If a TF_Select op does not have same static shape for it's value operands, its arguments may
-// need to be broadcasted at runtime. Only V2 supports this so we rewrite
-// before legalization.
-// TODO(b/270755115) This is a temporary workaround, investigate if this
-// pattern is needed in all situations.
-def ConvertNonStaticSelectToSelectV2 : Pat<
-  (TF_SelectOp:$src $cond, $x, $y),
-  (TF_SelectV2Op $cond, $x, $y),
-  [(NotOperandOneAndTwoHaveSameStaticShape $src)]>;
-
 //===----------------------------------------------------------------------===//
 // AddV2 op patterns.
 //===----------------------------------------------------------------------===//

tensorflow/compiler/mlir/lite/utils/utils.h

Lines changed: 14 additions & 21 deletions
@@ -16,8 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_
 
-#include <numeric>
-
 #include "llvm/ADT/ArrayRef.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
@@ -30,32 +28,27 @@ using mlir::Operation;
 using mlir::ShapedType;
 using mlir::Value;
 
-// Returns true if each Operand at given indices have the same static shape.
-inline bool OpHasSameStaticShapes(Operation* op,
-                                  llvm::ArrayRef<int> operand_idxs) {
-  if (op->getNumOperands() == 0 || operand_idxs.empty()) return true;
-  const int first_opr_idx = operand_idxs[0];
-  ArrayRef<int64_t> shape =
-      op->getOperand(first_opr_idx).getType().dyn_cast<ShapedType>().getShape();
-  for (int opr_idx : operand_idxs) {
-    Value operand = op->getOperand(opr_idx);
-    auto shaped_type = operand.getType().dyn_cast<ShapedType>();
+// Returns true if all tensor value in `values` has static shape and same shape.
+inline bool OpHasSameStaticShapes(Operation* op) {
+  auto values = op->getOperands();
+  int operand_num = 0;
+  ArrayRef<int64_t> shape;
+  for (Value value : values) {
+    auto shaped_type = value.getType().dyn_cast<ShapedType>();
     if (!shaped_type || !shaped_type.hasStaticShape()) {
       return false;
     }
-    if (shape != shaped_type.getShape()) {
-      return false;
+    if (operand_num == 0) {
+      shape = shaped_type.getShape();
+    } else {
+      if (shape != shaped_type.getShape()) {
+        return false;
+      }
     }
+    ++operand_num;
   }
   return true;
 }
-
-// Returns true if each Operand has the same static shape.
-inline bool OpHasSameStaticShapes(Operation* op) {
-  llvm::OwningArrayRef<int> operand_idxs(op->getNumOperands());
-  std::iota(operand_idxs.begin(), operand_idxs.end(), 0);
-  return OpHasSameStaticShapes(op, operand_idxs);
-}
 }  // namespace TFL
 }  // namespace mlir
