|
63 | 63 | #include "flang/Semantics/tools.h"
|
64 | 64 | #include "flang/Support/Version.h"
|
65 | 65 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
| 66 | +#include "mlir/IR/BuiltinAttributes.h" |
66 | 67 | #include "mlir/IR/Matchers.h"
|
67 | 68 | #include "mlir/IR/PatternMatch.h"
|
68 | 69 | #include "mlir/Parser/Parser.h"
|
@@ -2170,32 +2171,54 @@ class FirConverter : public Fortran::lower::AbstractConverter {
|
2170 | 2171 | return builder->createIntegerConstant(loc, controlType, 1); // step
|
2171 | 2172 | }
|
2172 | 2173 |
|
| 2174 | + // For unroll directives without a value, force full unrolling. |
| 2175 | + // For unroll directives with a value, if the value is greater than 1, |
| 2176 | + // force unrolling with the given factor. Otherwise, disable unrolling. |
| 2177 | + mlir::LLVM::LoopUnrollAttr |
| 2178 | + genLoopUnrollAttr(std::optional<std::uint64_t> directiveArg) { |
| 2179 | + mlir::BoolAttr falseAttr = |
| 2180 | + mlir::BoolAttr::get(builder->getContext(), false); |
| 2181 | + mlir::BoolAttr trueAttr = mlir::BoolAttr::get(builder->getContext(), true); |
| 2182 | + mlir::IntegerAttr countAttr; |
| 2183 | + mlir::BoolAttr fullUnrollAttr; |
| 2184 | + bool shouldUnroll = true; |
| 2185 | + if (directiveArg.has_value()) { |
| 2186 | + auto unrollingFactor = directiveArg.value(); |
| 2187 | + if (unrollingFactor == 0 || unrollingFactor == 1) { |
| 2188 | + shouldUnroll = false; |
| 2189 | + } else { |
| 2190 | + countAttr = |
| 2191 | + builder->getIntegerAttr(builder->getI64Type(), unrollingFactor); |
| 2192 | + } |
| 2193 | + } else { |
| 2194 | + fullUnrollAttr = trueAttr; |
| 2195 | + } |
| 2196 | + |
| 2197 | + mlir::BoolAttr disableAttr = shouldUnroll ? falseAttr : trueAttr; |
| 2198 | + return mlir::LLVM::LoopUnrollAttr::get( |
| 2199 | + builder->getContext(), /*disable=*/disableAttr, /*count=*/countAttr, {}, |
| 2200 | + /*full=*/fullUnrollAttr, {}, {}, {}); |
| 2201 | + } |
| 2202 | + |
2173 | 2203 | void addLoopAnnotationAttr(
|
2174 | 2204 | IncrementLoopInfo &info,
|
2175 | 2205 | llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
|
2176 |
| - mlir::BoolAttr f = mlir::BoolAttr::get(builder->getContext(), false); |
2177 |
| - mlir::BoolAttr t = mlir::BoolAttr::get(builder->getContext(), true); |
2178 | 2206 | mlir::LLVM::LoopVectorizeAttr va;
|
2179 | 2207 | mlir::LLVM::LoopUnrollAttr ua;
|
2180 | 2208 | bool has_attrs = false;
|
2181 | 2209 | for (const auto *dir : dirs) {
|
2182 | 2210 | Fortran::common::visit(
|
2183 | 2211 | Fortran::common::visitors{
|
2184 | 2212 | [&](const Fortran::parser::CompilerDirective::VectorAlways &) {
|
| 2213 | + mlir::BoolAttr falseAttr = |
| 2214 | + mlir::BoolAttr::get(builder->getContext(), false); |
2185 | 2215 | va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(),
|
2186 |
| - /*disable=*/f, {}, {}, |
2187 |
| - {}, {}, {}, {}); |
| 2216 | + /*disable=*/falseAttr, |
| 2217 | + {}, {}, {}, {}, {}, {}); |
2188 | 2218 | has_attrs = true;
|
2189 | 2219 | },
|
2190 | 2220 | [&](const Fortran::parser::CompilerDirective::Unroll &u) {
|
2191 |
| - mlir::IntegerAttr countAttr; |
2192 |
| - if (u.v.has_value()) { |
2193 |
| - countAttr = builder->getIntegerAttr(builder->getI64Type(), |
2194 |
| - u.v.value()); |
2195 |
| - } |
2196 |
| - ua = mlir::LLVM::LoopUnrollAttr::get( |
2197 |
| - builder->getContext(), /*disable=*/f, /*count*/ countAttr, |
2198 |
| - {}, /*full*/ u.v.has_value() ? f : t, {}, {}, {}); |
| 2221 | + ua = genLoopUnrollAttr(u.v); |
2199 | 2222 | has_attrs = true;
|
2200 | 2223 | },
|
2201 | 2224 | [&](const auto &) {}},
|
|
0 commit comments