
bytedance / byteir


A model compilation solution for various hardware

Home Page: https://byteir.ai

License: Apache License 2.0

CMake 0.59% C++ 21.89% C 0.14% MLIR 67.75% Python 4.93% Shell 0.09% LLVM 4.14% Starlark 0.09% Cuda 0.36% Dockerfile 0.02%
llvm mlir mlsys onnx pytorch tensorflow llm

byteir's Introduction

The ByteIR Project

English | 中文

The ByteIR Project is a ByteDance model compilation solution. ByteIR includes a compiler, runtime, and frontends, and provides an end-to-end model compilation solution.

Although all ByteIR components (compiler/runtime/frontends) work together to provide an end-to-end solution and live under the same umbrella of this repository, each component can technically be used independently.

The name ByteIR

The name ByteIR comes from a legacy internal purpose.
The ByteIR project is NOT an IR spec definition project. Instead, in most scenarios, ByteIR directly uses several upstream MLIR dialects and Google Mhlo. Most ByteIR compiler passes are compatible with the selected upstream MLIR dialects and Google Mhlo.

Why ByteIR

  • Enjoy SOTA models: ByteIR maintains popular frontends that handle lowering many SOTA models into Stablehlo, and also provides a model zoo (to be released soon) for research or benchmarking purposes.
  • Just work: ByteIR adopts upstream MLIR dialects and Google Mhlo, and provides compatible passes, utilities, and infrastructure for all compiler builders using upstream MLIR. You can mix ByteIR passes with upstream MLIR or Mhlo passes, or even your own passes, to build your pipeline.
  • Bring your own architecture: ByteIR provides rich, generic graph-, loop-, and tensor-level optimizations on Mhlo and Linalg, which DL ASIC compilers can reuse, so they only need to focus on the last mile for their backends.

Project Status

ByteIR is still in its early phase. In this phase, we are aiming to provide well-defined, necessary building blocks and infrastructure support for model compilation on a wide range of deep learning accelerators as well as general-purpose CPUs and GPUs. Therefore, highly tuned kernels for specific architectures might not have been prioritized yet. Of course, any feedback on prioritizing a specific architecture, as well as corresponding contributions, is welcome.

ByteIR Compiler is an MLIR-based compiler for CPU/GPU/ASIC.

ByteIR Runtime is a common, lightweight runtime capable of serving both existing kernels and kernels generated by the ByteIR compiler.

ByteIR Frontends include TensorFlow, PyTorch, and ONNX.

Components Communication Interface

Each ByteIR component can technically be used independently. There are pre-defined communication interfaces between components.

Stablehlo between frontends and compiler

ByteIR frontends and the ByteIR compiler communicate through the Stablehlo dialect, whose version might be updated during development.

This also implies that any frontend generating Stablehlo of a compatible version can work with the ByteIR compiler, and any compiler consuming Stablehlo of a compatible version can work with the ByteIR frontends.
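For illustration, the artifact exchanged at this boundary is an MLIR module in the Stablehlo dialect. Below is a minimal, hypothetical example (not taken from ByteIR itself; the function name and shapes are arbitrary) of what a frontend might emit and the compiler might consume:

func.func @forward(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  // an elementwise add expressed in the Stablehlo dialect
  %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32>
  return %0 : tensor<4xf32>
}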

ByRE between compiler and runtime

The ByteIR compiler and ByteIR runtime communicate through the ByRE format, whose version might be updated during development. The ByRE dialect is defined as one kind of ByRE format in the ByteIR compiler, currently supporting emission of a textual form or versioned bytecode for the ByteIR compiler and runtime.

Other ByRE formats are under development.

Publication and Citation

ByteIR is the product of many great researchers and interns in ByteDance. Below is a list of our public talks:

If you find ByteIR useful, please consider citing.

@misc{byteir2023,
title = {{ByteIR}},
author = {Cao, Honghua and Chang, Li-Wen and Chen, Chongsong and Jiang, Chengquan and Jiang, Ziheng and Liu, Liyang and Liu, Yuan and Liu, Yuanqiang and Shen, Chao and Wang, Haoran and Xiao, Jianzhe and Yao, Chengji and Yuan, Hangjian and Zhang, Fucheng and Zhang, Ru and Zhang, Xuanrun and Zhang, Zhekun and Zhang, Zhiwei and Zhu, Hongyu and Liu, Xin},
url = {https://github.com//bytedance/byteir},
year = {2023}
}

The ByteIR Project is under the Apache License v2.0

byteir's People

Contributors

connor-xy, eltociear, heromapwrd, jackychenyi, jianwenyyy, linuxlonelyeagle, liwenchangbdbz, mi-jiazhi, paran0idy, qingyunqu, sogartar, vremold, wenlei-bao, xg-zheng, xinyu302, yaochengji, yellowhch, yyp0, zhekunz2


byteir's Issues

[BUG] mhlo.batch_norm_inference not support

Error message:

Traceback (most recent call last):
  File "compile_resnet18.py", line 6, in <module>
    compile("./resnet18.mlirbc", "./resnet18.cuda.mlir", entry_func="forward", target="cuda", verbose=True)
  File "/usr/local/lib/python3.8/dist-packages/byteir/compile.py", line 220, in compile
    compile_cuda(module, output_file_path, entry_func, verbose)
  File "/usr/local/lib/python3.8/dist-packages/byteir/compile.py", line 40, in compile_cuda
    PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types {}}})".format(entry_func_str)).run(module.operation)
byteir._mlir_libs.MLIRError: Failure while executing pass pipeline:
error: "-":111:12: failed to legalize operation 'mhlo.batch_norm_inference' that was explicitly marked illegal
note: "-":111:12: see current operation: %205 = "mhlo.batch_norm_inference"(%204, %191, %189, %195, %193) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>

Hello, when using ByteIR to compile resnet18, the error says that mhlo.batch_norm_inference is not supported. Is it supported, or am I using the wrong version? Thank you.

Runtime compilation issues on Ubuntu22.04

Hi, I got two issues when trying to compile the runtime on Ubuntu 22.04:

  1. log:
/work/dev/byteir/byteir/runtime/include/brt/core/context/execution_frame.h:128:14: error: ‘unique_lock’ is not a member of ‘std’
  128 |         std::unique_lock<std::shared_mutex> lock_write(mutex);

Adding #include <mutex> to /runtime/include/brt/core/context/execution_frame.h can fix this issue.

  2. log:
/work/dev/byteir/byteir/runtime/include/brt/core/framework/op_kernel_info.h:92:20:   required from here
/usr/include/c++/11/bits/stl_pair.h:217:11: error: ‘std::pair<_T1, _T2>::first’ has incomplete type

Adding #include <string> to /runtime/include/brt/core/framework/op_kernel_info.h can fix this issue.

[RFC] Sharding Framework Design for Device Mesh


The original design doc can be found in the byteir repo, and the initial PR is also there.

As many of you are aware, as machine learning models continue to grow in size and complexity, the need for efficient and flexible distribution strategies becomes paramount. The device mesh concept, combined with a robust sharding framework, can significantly enhance performance, scalability, and flexibility across various hardware setups.

I'm reaching out to this community for two primary reasons:

  1. Collaboration: If you're already working on a sharding framework or see potential in this framework, I'd love to collaborate. There's immense value in pooling our collective expertise to make the sharding framework in MLIR robust and widely applicable.

  2. Feedback: Whether you're an expert in the field or someone with a keen interest, your feedback is invaluable. Constructive criticism, suggestions, or even pointing out potential pitfalls can significantly enhance the quality and applicability of this framework.

Mesh Dialect

The mesh dialect contains a set of attributes, operations, interfaces, and transformations that are useful for representing and optimizing computation on a device mesh.

MeshShardingAttr

Attribute that extends tensor type to distributed tensor type.

Syntax:

#mesh.shard<
  ::llvm::ArrayRef<::mlir::ArrayAttr>   # axes
>

The mesh.shard attribute is an array composed of int64_t sub-arrays. The outer array's maximum size is the rank of the related tensor plus one. For the i-th sub-array, if its value is [x, y]:

  • When i < rank, it indicates that the tensor's i-th dimension is sharded along the x and y axes of the device mesh.
  • When i == rank, it signifies that the tensor represents a partial sum along the x and y axes. More partial types could be introduced if needed, e.g. partial-max, partial-min.

Example:

// the tensor is sharded on the first dimension along axis 0
tensor<4x8xf32, #mesh.shard<[[0]]>>

// the tensor is sharded on the first dimension along axis 0 and it is also
// a partial-sum along axis 1.
tensor<4x8xf32, #mesh.shard<[[0], [], [1]]>>

Parameters:

Parameter C++ type Description
axes ::llvm::ArrayRef<::mlir::ArrayAttr>

ShardingIteratorType Enum

Currently there are only three sharding iterator types:

  • parallel: there should be an all-gather along the tensor dimension to get the full tensor.
  • reduction_sum: there should be an all-reduce-sum along the tensor dimension to get the full tensor. Other types of reduction could be introduced when needed, even a generic reduction type, where a payload body indicating what exactly the reduction is needs to be included.
  • invalid: it means the dimension cannot be sharded.

mesh.cluster (mesh::ClusterOp)

Representing a mesh cluster

Syntax:

operation ::= `mesh.cluster` $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)` attr-dict

The mesh.cluster operation is a symbol operation that identifies a specific mesh cluster, which can be used for distributed computations across a mesh topology. The operation has three attributes:

  1. sym_name: This attribute uniquely identifies the name of the mesh cluster. This name serves as a symbolic reference to the cluster throughout the MLIR module, allowing for consistent referencing and easier debugging.

  2. rank: This attribute specifies the number of axes of the cluster. The rank indicates the dimensionality of the mesh cluster and can be used to determine the layout and the addressing space of the computation distributed across the mesh.

  3. dim_sizes: This attribute represents the device assignment along the axes of the cluster. Each integer in the array corresponds to the number of devices along a specific axis. If an integer value is <= 0, it implies that the number of devices along that axis is unknown. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be determined during compile time.

Example:

// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])

// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.cluster @mesh1(rank = 2, dim_sizes = [4])

// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.cluster @mesh1(rank = 2, dim_sizes = [0, 4])

// a func op running on @mesh0
func.func(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> attributes
                                                { mesh_cluster = @mesh0 } {
  ...
}

Interfaces: Symbol

Attributes:

Attribute MLIR Type Description
sym_name ::mlir::StringAttr string attribute
rank ::mlir::IntegerAttr 64-bit signless integer attribute
dim_sizes ::mlir::ArrayAttr 64-bit integer array attribute

mesh.idx (mesh::IdxOp)

Get the index of the current device along a specified mesh axis.

Syntax:

operation ::= `mesh.idx` attr-dict `:` type($result)

It is used in the SPMD format of IR. Constraints:

  1. The axis must be non-negative and less than the total number of mesh axes.
  2. Its parent op must be a FuncOp with a mesh_cluster attribute.
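Example (an illustrative sketch following the declarative syntax above; the attribute form and the index result type are assumptions, not taken from the RFC):

// index of the current device along mesh axis 0
%idx = mesh.idx {axis = 0 : index} : index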

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
axis ::mlir::IntegerAttr index attribute

Results:

Result Description
result Integer-like type with unknown platform-dependent bit width

mesh.size (mesh::SizeOp)

Get the number of devices along a specified mesh axis.

Syntax:

operation ::= `mesh.size` attr-dict `:` type($result)

It is used in the SPMD format of IR.
Constraints:

  1. The axis must be non-negative and less than the total number of mesh axes.
  2. Its parent op must be a FuncOp with a mesh_cluster attribute.
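Example (an illustrative sketch following the declarative syntax above; the attribute form and the index result type are assumptions, not taken from the RFC):

// number of devices along mesh axis 1
%size = mesh.size {axis = 1 : index} : index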

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
axis ::mlir::IntegerAttr index attribute

Results:

Result Description
result Integer-like type with unknown platform-dependent bit width

mesh.annotate (mesh::AnnotateOp)

Annotate on how a tensor is sharded across a mesh cluster.

Syntax:

operation ::= `mesh.annotate` $input attr-dict `:` type($input) `->` type($output)

The mesh.annotate operation is designed to specify and guide the sharding
behavior of a tensor value across a mesh topology. It offers both strict
requirements and hints for the sharding process, allowing for flexibility
in distributed computations. This operation has one operand and three
attributes:

  1. input: This operand represents the tensor value that needs to be
    annotated for sharding.

  2. sharding: An array of int64 arrays with a maximum size equal to the
    rank of the input tensor plus one. Each element of the outer array
    corresponds to a dimension of the input tensor, except for the last element
    which signifies the tensor as a partial-sum. Each inner int64 array lists
    the axes to shard on. An axis will be sharded along at most one input
    dimension. If an axis is not present in any of the inner arrays, it
    indicates that the tensor will be replicated along that axis in the mesh.

  3. required: A boolean attribute. When set to true, it mandates the
    compiler to adhere to the sharding annotation specified. If set to false,
    the sharding annotation serves merely as a hint, allowing the compiler
    discretion in optimizing the sharding strategy.

  4. as_result: A boolean attribute addressing the scenario when a tensor's
    sharding annotation differs based on its context of use (either as a result
    or an operand). If true, the annotation applies to the operation that
    defines the tensor value. If false, the annotation pertains to specific
    users of the tensor value, indicating how it should be considered when used
    as an operand in subsequent operations.

Example:

// The first mesh.annotate op applies to op0, the second mesh.annotate op
// applies to op1, the third mesh.annotate op applies to op2
%0 = op0 ...
%1 = mesh.annotate %0 {sharding = [[0], [1]], required = true,
        as_result = true} : tensor<2x5xf32> -> tensor<2x5xf32>
%2 = mesh.annotate %1 {sharding = [[0]], required = true,
        as_result = false} : tensor<2x5xf32> -> tensor<2x5xf32>
%3 = op1(%2) : ...
%4 = mesh.annotate %1 {sharding = [[1]], required = true,
        as_result = false} : tensor<2x5xf32> -> tensor<2x5xf32>
%5 = op2(%4) : ...

// The mesh.annotation op applies to op0, the op1's operand has no
// annotation
%0 = op0 ...
%1 = mesh.annotate %0 {sharding = [[0], [1]], required = true,
        as_result = true} : tensor<2x5xf32> -> tensor<2x5xf32>
%2 = op1(%1) : ...

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
sharding ::mlir::ArrayAttr array attribute
required ::mlir::BoolAttr bool attribute
as_result ::mlir::BoolAttr bool attribute

Operands:

Operand Description
input Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
output Multi-dimensional array with a fixed number of dimensions

mesh.all_gather (mesh::AllGatherOp)

All-gather op in device mesh

Syntax:

operation ::= `mesh.all_gather` $src attr-dict `:` type($src) `->` type($result)

The operation is designed to facilitate all-gather computations specifically
within the context of a device mesh. It works with tensors that are
distributed across the device mesh, and these tensors are essentially the
builtin ranked tensors extended with the MeshShardingAttr. It has one
attribute:

  1. mesh_axis: An array of int64 arrays, representing the axes of the device
    mesh where the all-gather operation will be applied.

Example:

%1 = mesh.all_gather %0 {mesh_axis = [[0], [1]]} :
  tensor<2x4xf32, #mesh.shard<[[0], [1]]>> -> tensor<2x4xf32>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
mesh_axis ::mlir::ArrayAttr array attribute

Operands:

Operand Description
src Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
result Multi-dimensional array with a fixed number of dimensions

mesh.all_reduce (mesh::AllReduceOp)

All-reduce op in device mesh

Syntax:

operation ::= `mesh.all_reduce` $src attr-dict `:` type($src) `->` type($result)

The operation is designed to facilitate all-reduce computations specifically
within the context of a device mesh. It works with tensors that are
distributed across the device mesh, and these tensors are essentially the
builtin ranked tensors extended with the MeshShardingAttr. It has two
attributes:

  1. mesh_axis: An int64 array representing the axes of the device mesh
    where the all-reduce operation will be applied.

  2. reduction: Indicates the reduction method.

Example:

%1 = mesh.all_reduce %0 {reduction = "sum", mesh_axis = [0]} :
  tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0, 1]]>> -> tensor<2x4x8xf32, #mesh.shard<[[], [], [], [1]]>>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
mesh_axis ::mlir::ArrayAttr 64-bit integer array attribute
reduction ::mlir::StringAttr string attribute

Operands:

Operand Description
src Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
result Multi-dimensional array with a fixed number of dimensions

mesh.all_to_all (mesh::AllToAllOp)

TODO

mesh.local_split (mesh::LocalSplitOp)

Split a ranked tensor locally

Syntax:

operation ::= `mesh.local_split` $src attr-dict `:` type($src) `->` type($result)

The operation represents splitting a ranked tensor locally, specifically
within the context of a device mesh. It works with tensors that are
distributed across the device mesh, and these tensors are essentially the
builtin ranked tensors extended with the MeshShardingAttr. It has one
attribute:

  1. sharding: An array of int64 arrays with a maximum size equal to the
    rank of the src tensor. Each element of the outer array corresponds to a
    dimension of the src tensor. Each inner int64 array lists the axes to
    split on. An axis will be sharded along at most one dimension, and it
    should not appear in the MeshShardingAttr of the src tensor.

Example:

%1 = mesh.local_split %0 {sharding = [[], [], [0]]} : tensor<2x4x8xf32, #mesh.shard<[[], [1], []]>>
       -> tensor<2x4x8xf32, #mesh.shard<[[], [1], [0]]>>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferDTensorInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
sharding ::mlir::ArrayAttr array attribute

Operands:

Operand Description
src Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
result Multi-dimensional array with a fixed number of dimensions

mesh.reduce_scatter (mesh::ReduceScatterOp)

Reduce-scatter op in device mesh

Syntax:

operation ::= `mesh.reduce_scatter` $src attr-dict `:` type($src) `->` type($result)

The operation is designed to facilitate reduce-scatter computations specifically
within the context of a device mesh. It works with tensors that are
distributed across the device mesh, and these tensors are essentially the
builtin ranked tensors extended with the MeshShardingAttr. It has three
attributes:

  1. mesh_axis: An int64 array representing the axes of the device mesh
    where the reduce-scatter operation will be applied.

  2. reduction: Indicates the reduction method.

  3. tensor_axis: Indicates the axis to scatter.

Example:

%1 = mesh.reduce_scatter %0 {mesh_axis = [0], reduction = "sum", tensor_axis = 2 : i64} :
   tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0]]>> -> tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferDTensorInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
mesh_axis ::mlir::ArrayAttr 64-bit integer array attribute
reduction ::mlir::StringAttr string attribute
tensor_axis ::mlir::IntegerAttr 64-bit signless integer attribute

Operands:

Operand Description
src Multi-dimensional array with a fixed number of dimensions

Results:

Result Description
result Multi-dimensional array with a fixed number of dimensions

ShardingInterface

The ShardingInterface is an interface within MLIR that enables operations to
provide necessary information for sharding. This interface is primarily used for
sharding propagation. The interface is composed of four methods, out of which
two must be overridden for each operation type. The remaining two methods
provide default implementations which suffice for a majority of operation types.

Methods

getLoopIteratorTypes()

  • Description: This method provides a list of iterator types associated with the
    number of loops within the operation.
  • Return Type: SmallVector<mlir::mesh::ShardingIteratorType>
  • Details: The iterator types can be one of the following:
    • parallel: If the loop is sharded based on this iterator type, a subsequent
      all-gather is required after the sharded operation to produce the complete
      tensor.
    • reduction: When sharded on this loop, a subsequent all-reduce operation is
      essential post the sharded operation to generate the complete tensor.
    • invalid: This signifies that the loop cannot undergo sharding.

getIndexingMaps()

  • Description: Offers the indexing maps associated with the current operation.
  • Return Type: SmallVector<AffineMap>
  • Details: These are affine maps, translating between loop iterators and tensor
    indices. Affine maps are formed from linear combinations and constants. The
    indexing maps of the operation results are restricted to projected
    permutations.

getShardingOption(OpBuilder &b)

  • Description: Given that certain operands or results of the operation may be
    annotated, this method leverages this information to deduce how the operation
    should be sharded.
  • Return Type: FailureOr<ShardingOption>
  • Details: ShardingOption is represented as an array of int64 arrays. The
    sub-array at the i-th position signifies the mesh axes the i-th loop will be
    sharded on.
  • Default implementation logic:
    1. Check for Existing Attribute: If the operation already possesses a
      ShardingOption attribute, return this attribute immediately.
    2. Initialization: Instantiate an empty `ShardingOption`. This should be an
      array containing int64 sub-arrays, each corresponding to a loop in the
      operation.
    3. Results Annotation Handling:
      • Iterate over all the results of the operation. If a result has an
        annotation:
        • Map the tensor dimensions to loop iterators.
        • Set the corresponding axes based on the mapped loop iterators.
        • In cases where there's a conflict with previously set axes, it implies
          an invalid sharding annotation. In such instances, flag this
          inconsistency for subsequent error handling or correction.
    4. Operands Annotation Handling:
      • Iterate over all the operands of the operation, using the information
        from:
        • Reduction iterator loops and
        • Unhandled parallel iterator loops
      • Validate the remaining iterator loops. If discrepancies arise during
        validation, take appropriate corrective actions or raise errors.
    5. Replication of Mesh Axes: Any mesh axes that haven't been addressed or
      mapped during the above steps should be treated as replicated axes.
    6. Return Logic:
      • If the constructed or modified ShardingOption is valid, return it.
      • If inconsistencies or errors were detected, return `failure()`.

setShardingAnnotations(OpBuilder &b, const ShardingOption &option)

  • Description: Based on a given ShardingOption, this method annotates those
    operands and results which previously lacked sharding annotations.
  • Return Type: LogicalResult
  • Details: The primary role is to propagate sharding annotations throughout
    the operation based on the provided sharding options.
  • Default implementation logic:
    1. Results Annotation Handling: Given the constraints of the result indexing
      maps, which are limited to projected permutations, there can only be a
      single DimId across all the result indexing maps.
    • For parallel loop iterators: Establish and assign the corresponding axes
      based on the mapped loop iterators.
    • For reduction loops: Append additional axes to the end of the existing
      annotations to indicate their association with the reduction loops.
    2. Operands Annotation Handling: Operand annotations pose a more intricate
      challenge compared to results due to the possibility that they might not
      strictly adhere to projected permutations.
    • Here, we constrain the results of the operand's indexing maps to a
      representation format: $c_i * d_i + c_j * d_j + ...$. In this
      representation:
      • $c_i$ and $c_j$ denote constants. If a constant has a value of one, it may
        be excluded from the representation.
      • $d_i$ and $d_j$ represent the DimIds.
    • In situations where the representation contains multiple DimIds:
      Sharding can only be applied to at most one of them. This constraint
      ensures that the operand annotations don't introduce excessive complexity
      and retain predictability in their sharding behavior.

Sharding Propagation Pass

The sharding propagation pass aims to address two primary objectives:

  1. Sharding Annotation Completion: Computational graphs often have incomplete
    sharding annotations. This pass is designed to fill in these gaps.
  2. Distributed Tensor materialization: Once the computational graph is fully
    annotated, this pass will convert it into distributed tensors and incorporate
    concrete communication operations.

Implementation Logic:

  1. Backward Sharding Propagation:
  • Traverse all operations that implement the `ShardingInterface`, iterating
    in reverse order.
  • For each operation, invoke the getShardingOption and
    setShardingAnnotations methods.
  2. Forward Sharding Propagation:
  • Traverse all operations that implement the `ShardingInterface`, but this
    time in a non-reversed (forward) order.
  • Similarly, for each operation, call the getShardingOption and
    setShardingAnnotations methods.
  3. Annotation Operations Handling: Process all annotation operations in reverse
    order:
    • Result Annotations (as_result = true): Extend the type of the annotated
      value by incorporating a MeshShardingAttr. This attribute is derived from
      the annotation operation itself.
    • Operand Annotations (as_result = false): Introduce additional communication
      operations. The final produced value will replace the result of the
      original annotation operation. Note: At this stage, the logic for
      communication creation can be kept straightforward. Further
      canonicalization and optimization of these communications can be executed
      later. The process can be categorized into three stages:
      • All-Reduce: If any reduction sharding axes are absent in the
        current annotation operation relative to its operand's defining operation
        (which should also be an annotation operation with as_result = true),
        an all-reduce operation should be initialized.
      • All-Gather: Create an all-gather operation to reconstruct the
        complete tensor.
      • Local-Split: Launch a local-split operation to derive the final
        sharded tensor.

Collective Communication Optimization Passes

After the sharding propagation pass, collective communication optimization aims
to further streamline and optimize communication operations. Some passes are
listed below as examples:

All-Reduce Folder Pass

  • Purpose: To consolidate successive all-reduce operations for efficiency.

  • Description: This pass identifies scenarios where one all-reduce operation feeds
    directly into another. When detected, the to-reduce mesh axes are expanded,
    leading to a folded representation and reduced redundancy.
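
An illustrative sketch (the shapes and sharding attributes below are hypothetical, and the fully reduced result type is an assumption):

// before: two successive all-reduce ops
%1 = mesh.all_reduce %0 {reduction = "sum", mesh_axis = [0]} :
  tensor<2x4xf32, #mesh.shard<[[], [], [0, 1]]>> -> tensor<2x4xf32, #mesh.shard<[[], [], [1]]>>
%2 = mesh.all_reduce %1 {reduction = "sum", mesh_axis = [1]} :
  tensor<2x4xf32, #mesh.shard<[[], [], [1]]>> -> tensor<2x4xf32>

// after folding: a single all-reduce over the expanded mesh axes
%1 = mesh.all_reduce %0 {reduction = "sum", mesh_axis = [0, 1]} :
  tensor<2x4xf32, #mesh.shard<[[], [], [0, 1]]>> -> tensor<2x4xf32>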

All-Reduce Reassociate Pass

  • Purpose: To streamline multiple all-reduce operations acting on elementwise
    operations.

  • Description: This pass identifies patterns where multiple all-reduce
    operations are applied to the results of elementwise operations. Upon detection,
    the pass reassociates these operations to reduce the number of collective
    communications. For instance, the sequence add(all-reduce(x), all-reduce(y))
    would be transformed into all-reduce(add(x,y)).
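
An illustrative sketch (shapes and sharding attributes are hypothetical):

// before: two all-reduce ops feed an elementwise add
%2 = mesh.all_reduce %0 {reduction = "sum", mesh_axis = [0]} :
  tensor<2x4xf32, #mesh.shard<[[], [], [0]]>> -> tensor<2x4xf32>
%3 = mesh.all_reduce %1 {reduction = "sum", mesh_axis = [0]} :
  tensor<2x4xf32, #mesh.shard<[[], [], [0]]>> -> tensor<2x4xf32>
%4 = mhlo.add %2, %3 : tensor<2x4xf32>

// after reassociation: add the partial results locally, then a single all-reduce
%2 = mhlo.add %0, %1 : tensor<2x4xf32, #mesh.shard<[[], [], [0]]>>
%3 = mesh.all_reduce %2 {reduction = "sum", mesh_axis = [0]} :
  tensor<2x4xf32, #mesh.shard<[[], [], [0]]>> -> tensor<2x4xf32>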

Reduce-Scatter Reassociate Pass

  • Purpose: To optimize multiple reduce-scatter operations that act on
    elementwise operations.

  • Description: This pass detects patterns where multiple reduce-scatter
    operations are applied to the results of elementwise operations. When such
    patterns are identified, the pass reassociates these operations to consolidate
    the collective communications. As an example, a sequence like
    add(reduce-scatter(x), reduce-scatter(y)) would be reshaped into
    reduce-scatter(add(x,y)).
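
An illustrative sketch (shapes and sharding attributes are hypothetical):

// before: two reduce-scatter ops feed an elementwise add
%2 = mesh.reduce_scatter %0 {mesh_axis = [0], reduction = "sum", tensor_axis = 1 : i64} :
  tensor<2x4xf32, #mesh.shard<[[], [], [0]]>> -> tensor<2x4xf32, #mesh.shard<[[], [0]]>>
%3 = mesh.reduce_scatter %1 {mesh_axis = [0], reduction = "sum", tensor_axis = 1 : i64} :
  tensor<2x4xf32, #mesh.shard<[[], [], [0]]>> -> tensor<2x4xf32, #mesh.shard<[[], [0]]>>
%4 = mhlo.add %2, %3 : tensor<2x4xf32, #mesh.shard<[[], [0]]>>

// after reassociation: add the partial results locally, then a single reduce-scatter
%2 = mhlo.add %0, %1 : tensor<2x4xf32, #mesh.shard<[[], [], [0]]>>
%3 = mesh.reduce_scatter %2 {mesh_axis = [0], reduction = "sum", tensor_axis = 1 : i64} :
  tensor<2x4xf32, #mesh.shard<[[], [], [0]]>> -> tensor<2x4xf32, #mesh.shard<[[], [0]]>>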

All-Gather Move Down Pass

  • Purpose: To reposition all-gather operations for improved efficiency in the
    computational flow.

  • Description: This pass targets scenarios where an all-gather operation
    precedes operations that have a parallel loop type for gathering. In such
    situations, the all-gather operation is shifted downwards in the sequence.
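
An illustrative sketch (shapes and sharding attributes are hypothetical):

// before: an all-gather precedes an elementwise (parallel-loop) op
%1 = mesh.all_gather %0 {mesh_axis = [[0]]} :
  tensor<2x4xf32, #mesh.shard<[[0]]>> -> tensor<2x4xf32>
%2 = mhlo.abs %1 : tensor<2x4xf32>

// after moving the all-gather down: compute on the local shard, then gather
%1 = mhlo.abs %0 : tensor<2x4xf32, #mesh.shard<[[0]]>>
%2 = mesh.all_gather %1 {mesh_axis = [[0]]} :
  tensor<2x4xf32, #mesh.shard<[[0]]>> -> tensor<2x4xf32>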

Sharding Mutations

While the logic of the sharding propagation pass is designed for simplicity, it doesn't always yield the most optimal outcome. An optional sharding mutation can be introduced to modify the sharding result of the IR. The sharding mutation is usually determined by the analysis of communication and computation based on current sharded IR.

Sharding Partition

This pass transforms distributed tensors into specific tensors for each device. Additionally, it converts mesh CCL operations into more defined CCL ops with device IDs. Based on various scenarios, we can opt for one of two types of result IR:

  1. The physical weights/arguments are already partitioned across different devices, eliminating the need to retain this information in the IR. In this case,
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func(%arg0: tensor<8xf32, #mesh.shard<[[0]]>>) -> () attributes { mesh_cluster = @mesh0 } {
  "use"(%arg0) ...
  ...
}

// will be converted to 

mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func(%arg0: tensor<4xf32>) -> () attributes { mesh_cluster = @mesh0 } {
  "use"(%arg0) ...
  ...
}

  2. The physical weights/arguments haven't been partitioned for individual devices, necessitating knowledge of the actual slice information. In this case,
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func(%arg0: tensor<8xf32, #mesh.shard<[[0]]>>) -> () attributes { mesh_cluster = @mesh0 } {
  "use"(%arg0) ...
  ...
}

// will be converted to 

mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func(%arg0: tensor<8xf32>) -> () attributes { mesh_cluster = @mesh0 } {
  %idx = mesh.idx(0)
  %c4 = arith.constant 4 : i64
  %start = arith.muli %idx, %c4 : i64
  %arg0_slice = "mhlo.dynamic_slice"(%arg0, %start) {
    slice_sizes = dense<[4]> : tensor<1xi64>
  } : (tensor<8xf32>, i64) -> tensor<4xf32>
  "use"(%arg0_slice) ...
  ...
}

End2End Walkthrough Example

MLP 1D weight stationary with all sharding annotation on tensors

https://arxiv.org/abs/2211.05102, figure 2(a)

  1. Original IR
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func @mlp(%arg0: tensor<2x4x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x8xf32>) -> tensor<2x4x8xf32> attributes { mesh_cluster = @mesh0 } {
  %0 = mesh.annotate %arg0 {required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %1 = "mhlo.dot_general"(%0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>,
                                      precision_config =
                                      [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
  %2 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
  %3 = mhlo.maximum %1, %2 : tensor<2x4x32xf32>
  %4 = "mhlo.dot_general"(%3, %arg2) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>,
                                      precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
  %5 = mesh.annotate %4 {required = true, sharding = [[], [], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %6 = mesh.annotate %5 {as_result = false, required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  return %6 : tensor<2x4x8xf32>
}
  2. Loop types and indexing maps
%3 = mhlo.maximum %1, %2 : tensor<2x4x32xf32>
loop types: [parallel parallel parallel ]
indexing maps:
(d0, d1, d2) -> (d0, d1, d2)
(d0, d1, d2) -> (d0, d1, d2)
(d0, d1, d2) -> (d0, d1, d2)
%1 = "mhlo.dot_general"(%0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>,
                                      precision_config =
                                      [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
loop types: [parallel parallel parallel reduction_sum ]
indexing maps:
(d0, d1, d2, d3) -> (d0, d1, d3)
(d0, d1, d2, d3) -> (d3, d2)
(d0, d1, d2, d3) -> (d0, d1, d2)
%4 = "mhlo.dot_general"(%3, %arg2) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>,
                                      precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
loop types: [parallel parallel parallel reduction_sum ]
indexing maps:
(d0, d1, d2, d3) -> (d0, d1, d3)
(d0, d1, d2, d3) -> (d3, d2)
(d0, d1, d2, d3) -> (d0, d1, d2)
  3. After annotation completion
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x8xf32>) -> tensor<2x4x8xf32> attributes { mesh_cluster = @mesh0 } {
  %0 = mesh.annotate %arg1 {as_result = false, required = false, sharding = [[], [0]]} : tensor<8x32xf32> -> tensor<8x32xf32>
  %1 = mesh.annotate %arg2 {as_result = false, required = false, sharding = [[0]]} : tensor<32x8xf32> -> tensor<32x8xf32>
  %2 = mesh.annotate %arg0 {required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %3 = mesh.annotate %2 {as_result = false, required = false, sharding = []} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %4 = "mhlo.dot_general"(%3, %0) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [0]]} : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
  %5 = mesh.annotate %4 {required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %6 = mesh.annotate %5 {as_result = false, required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %7 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
  %8 = mesh.annotate %7 {required = false, sharding = []} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %9 = mesh.annotate %8 {as_result = false, required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %10 = mhlo.maximum %6, %9 {sharding = [[], [], [0]]} : tensor<2x4x32xf32>
  %11 = mesh.annotate %10 {required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %12 = mesh.annotate %11 {as_result = false, required = false, sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32>
  %13 = "mhlo.dot_general"(%12, %1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [], [0]]} : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
  %14 = mesh.annotate %13 {required = true, sharding = [[], [], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %15 = mesh.annotate %14 {as_result = false, required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  return %15 : tensor<2x4x8xf32>
}
  4. After sharding materialization
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

  func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>, %arg1: tensor<8x32xf32, #mesh.shard<[[], [0]]>>, %arg2: tensor<32x8xf32, #mesh.shard<[[0]]>>) -> tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>> attributes { mesh_cluster = @mesh0 } {
    %0 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
    %1 = mesh.all_gather %arg0 {mesh_axis = [[], [], [0]], tensor_axis = [2]} : tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>> -> tensor<2x4x8xf32>
    %2 = "mhlo.dot_general"(%1, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [0]]} : (tensor<2x4x8xf32>, tensor<8x32xf32, #mesh.shard<[[], [0]]>>) -> tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>
    %3 = mesh.local_split %0 {sharding = [[], [], [0]]} : tensor<2x4x32xf32> -> tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>
    %4 = mhlo.maximum %2, %3 {sharding = [[], [], [0]]} : tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>
    %5 = "mhlo.dot_general"(%4, %arg2) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [], [0]]} : (tensor<2x4x32xf32, #mesh.shard<[[], [], [0]]>>, tensor<32x8xf32, #mesh.shard<[[0]]>>) -> tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0]]>>
    %6 = mesh.reduce_scatter %5 {mesh_axis = [0], reduction = "sum", tensor_axis = 2 : i64} : tensor<2x4x8xf32, #mesh.shard<[[], [], [], [0]]>> -> tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>
    return %6 : tensor<2x4x8xf32, #mesh.shard<[[], [], [0]]>>
  }
  5. After sharding partition
mesh.cluster @mesh0(rank = 1, dim_sizes = [2])

func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x4xf32>, %arg1: tensor<8x16xf32>, %arg2: tensor<16x8xf32>) -> tensor<2x4x4xf32> attributes {mesh_cluster = @mesh0} {
  %0 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
  %1 = "mhlo.all_gather"(%arg0) {all_gather_dim = 2 : i64, channel_handle = #mhlo.channel_handle<handle = 0, type = 0>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<2x4x4xf32>) -> tensor<2x4x8xf32>
  %2 = "mhlo.dot_general"(%1, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [0]]} : (tensor<2x4x8xf32>, tensor<8x16xf32>) -> tensor<2x4x16xf32>
  %3 = mhlo.constant dense<0.000000e+00> : tensor<2x4x16xf32>
  %4 = mhlo.maximum %2, %3 {sharding = [[], [], [0]]} : tensor<2x4x16xf32>
  %5 = "mhlo.dot_general"(%4, %arg2) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], sharding = [[], [], [], [0]]} : (tensor<2x4x16xf32>, tensor<16x8xf32>) -> tensor<2x4x8xf32>
  %6 = "mhlo.reduce_scatter"(%5) ({
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    %7 = mhlo.add %arg3, %arg4 : tensor<f32>
    mhlo.return %7 : tensor<f32>
  }) {channel_handle = #mhlo.channel_handle<handle = 0, type = 0>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, scatter_dimension = 2 : i64} : (tensor<2x4x8xf32>) -> tensor<2x4x4xf32>
  return %6 : tensor<2x4x4xf32>
}

MLP 1D weight stationary with sharding option on operations

  1. Original IR
mesh.cluster @mesh0(rank = 1, dim_sizes = [8])

func.func @mlp_1d_weight_stationary_with_sharding_on_operation(%arg0: tensor<2x4x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x8xf32>) -> tensor<2x4x8xf32> attributes { mesh_cluster = @mesh0 } {
  %0 = mesh.annotate %arg0 {required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  %1 = "mhlo.dot_general"(%0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2], 
                                      rhs_contracting_dimensions = [0]>, 
                                      precision_config = 
                                      [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<2x4x8xf32>, tensor<8x32xf32>) -> tensor<2x4x32xf32>
  %2 = mhlo.constant dense<0.000000e+00> : tensor<2x4x32xf32>
  %3 = mhlo.maximum %1, %2 : tensor<2x4x32xf32>
  %4 = "mhlo.dot_general"(%3, %arg2) {
    dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [2],
                                      rhs_contracting_dimensions = [0]>, 
                                      precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    sharding = [[], [], [], [0]]
  } : (tensor<2x4x32xf32>, tensor<32x8xf32>) -> tensor<2x4x8xf32>
  %6 = mesh.annotate %4 {as_result = false, required = true, sharding = [[], [], [0]]} : tensor<2x4x8xf32> -> tensor<2x4x8xf32>
  return %6 : tensor<2x4x8xf32>
}
  2. The result of the pass is the same as in the first example.

Q & A

What are the differences and connections between MeshShardingAttr, mesh.annotate, and operation ShardingOption?

  1. MeshShardingAttr and mesh.annotate:
  • Purpose: Both aim to represent distributed tensors.
  • Differences: MeshShardingAttr serves as an optional encoding for
    RankedTensorType, offering a more concise expression. In contrast,
    mesh.annotate introduces an additional operation, ensuring that information
    isn't lost after executing a pass.
  2. operation ShardingOption:
  • This pertains to the sharding annotations of an operation. It more precisely
    depicts how an operation is sharded. Due to its need for a deeper
    understanding of operation computations, it isn't typically exposed to
    end-users.

Why isn't the framework built based on the upstream TilingInterface system?

While utilizing the tiling interface might seem like a logical choice, it
actually introduces certain intricacies. This interface would encode specific
slices directly into the IR. However, during the sharding optimization phase,
this specific slice information isn't necessary. Including it could
inadvertently complicate the implementation.

For instance, sharding naturally denotes that each device has an equal logical
division. If we then dive deeper into lower-level operations like
tensor.extract_slice/insert_slice, this even distribution information could
potentially be lost, which is counterintuitive for sharding.

How do we represent MOE (expert parallel)?

To represent MOE (expert parallel), a tensor would introduce an additional
expert dimension. Alongside, there should be a concept of a "virtual axis".
This comes into play especially when the number of experts is less than the
number of devices on the physical axis. The approach is somewhat akin to how
it's done with tensor parallel.

How does the sharding framework compare to XLA's GSPMD?

  1. Unified Interface: For the majority of operations, only getIndexingMaps and
    getLoopIteratorTypes need to be implemented.
  2. Strategy Independence in Propagation: The propagation phase doesn't employ
    sharding optimization strategies.
  3. Sharding Option at Operation Level: This provides an option for sharding at
    the operation level, making it more convenient for automatic parallel algorithms
    to set sharding strategies with precision.
  4. Explicit Communication Post Propagation: The results after propagation
    explicitly depict communication, facilitating efficient analysis and
    optimization.

[Compiler] FuseExt bug on 5D elementwise tiling

Original MLIR:

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
module {
  func.func @Unknown41(%arg0: tensor<15x4x32x32x1xf16>, %arg1: tensor<15x4x32x32x1xf16>, %arg2: tensor<15x4x32x32xf16>, %arg3: tensor<15x4x32x32xf16>) -> (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) attributes {__byteir_elementwise_fusion__} {
    %expanded = tensor.expand_shape %arg2 [[0], [1], [2], [3, 4]] : tensor<15x4x32x32xf16> into tensor<15x4x32x32x1xf16>
    %expanded_0 = tensor.expand_shape %arg3 [[0], [1], [2], [3, 4]] : tensor<15x4x32x32xf16> into tensor<15x4x32x32x1xf16>
    %0 = tensor.empty() : tensor<15x4x32x32x1xf16>
    %1:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %expanded_0, %expanded, %expanded_0, %expanded : tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) outs(%0, %0 : tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) attrs =  {__byteir_gpu_tile_elementwise_0} {
    ^bb0(%in: f16, %in_1: f16, %in_2: f16, %in_3: f16, %in_4: f16, %in_5: f16, %out: f16, %out_6: f16):
      %2 = arith.mulf %in_1, %in_2 : f16
      %3 = arith.mulf %in, %in_3 : f16
      %4 = arith.subf %3, %2 : f16
      %5 = arith.mulf %in, %in_4 : f16
      %6 = arith.mulf %in_1, %in_5 : f16
      %7 = arith.addf %6, %5 : f16
      linalg.yield %4, %7 : f16, f16
    } -> (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>)
    return %1#0, %1#1 : tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>
  }
  transform.sequence  failures(propagate) {
  ^bb0(%arg0: !pdl.operation):
    %0 = transform.structured.match attributes {__byteir_gpu_tile_elementwise_0} in %arg0 : (!pdl.operation) -> !pdl.operation
    %transformed, %loops:5 = transform.structured.fuse_ext %0 {tile_interchange = [], tile_sizes = [1, 1, 1, 1, 1]}
  }
}

Command to reproduce:

./build/bin/byteir-opt -transform-dialect-interpreter input.mlir

Full Error message:

../test.mlir:5:8: error: 'tensor.expand_shape' op invalid to have a single dimension (3) expanded into multiple dynamic dims (3,4)
  %3 = mhlo.multiply %1, %arg3 : tensor<15x4x32x32xf16>
       ^
../test.mlir:5:8: note: see current operation: %60 = "tensor.expand_shape"(%59) <{reassociation = [[0], [1], [2], [3, 4]]}> : (tensor<?x?x?x?xf16>) -> tensor<?x?x?x?x?xf16>
// -----// IR Dump After TransformDialectInterpreter Failed (transform-dialect-interpreter) //----- //
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map1 = affine_map<(d0) -> (d0)>
"builtin.module"() ({
  "func.func"() <{function_type = (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32xf16>, tensor<15x4x32x32xf16>) -> (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>), sym_name = "Unknown41"}> ({
  ^bb0(%arg0: tensor<15x4x32x32x1xf16>, %arg1: tensor<15x4x32x32x1xf16>, %arg2: tensor<15x4x32x32xf16>, %arg3: tensor<15x4x32x32xf16>):
    %0 = "arith.constant"() <{value = 1 : index}> : () -> index
    %1 = "arith.constant"() <{value = 1 : index}> : () -> index
    %2 = "arith.constant"() <{value = 1 : index}> : () -> index
    %3 = "arith.constant"() <{value = 1 : index}> : () -> index
    %4 = "arith.constant"() <{value = 1 : index}> : () -> index
    %5 = "tensor.expand_shape"(%arg2) <{reassociation = [[0], [1], [2], [3, 4]]}> : (tensor<15x4x32x32xf16>) -> tensor<15x4x32x32x1xf16>
    %6 = "tensor.expand_shape"(%arg3) <{reassociation = [[0], [1], [2], [3, 4]]}> : (tensor<15x4x32x32xf16>) -> tensor<15x4x32x32x1xf16>
    %7 = "tensor.empty"() : () -> tensor<15x4x32x32x1xf16>
    %8:2 = "linalg.generic"(%arg0, %arg1, %6, %5, %6, %5, %7, %7) <{indexing_maps = [#map, #map, #map, #map, #map, #map, #map, #map], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operand_segment_sizes = array<i32: 6, 2>}> ({
    ^bb0(%arg4: f16, %arg5: f16, %arg6: f16, %arg7: f16, %arg8: f16, %arg9: f16, %arg10: f16, %arg11: f16):
      %12 = "arith.mulf"(%arg5, %arg6) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
      %13 = "arith.mulf"(%arg4, %arg7) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
      %14 = "arith.subf"(%13, %12) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
      %15 = "arith.mulf"(%arg4, %arg8) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
      %16 = "arith.mulf"(%arg5, %arg9) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
      %17 = "arith.addf"(%16, %15) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
      "linalg.yield"(%14, %17) : (f16, f16) -> ()
    }) {__byteir_gpu_tile_elementwise_0} : (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>)
    %9 = "arith.constant"() <{value = 0 : index}> : () -> index
    %10 = "arith.constant"() <{value = 15 : index}> : () -> index
    %11:2 = "scf.for"(%9, %10, %0, %7, %7) ({
    ^bb0(%arg4: index, %arg5: tensor<15x4x32x32x1xf16>, %arg6: tensor<15x4x32x32x1xf16>):
      %12 = "arith.constant"() <{value = 0 : index}> : () -> index
      %13 = "arith.constant"() <{value = 4 : index}> : () -> index
      %14:2 = "scf.for"(%12, %13, %1, %arg5, %arg6) ({
      ^bb0(%arg7: index, %arg8: tensor<15x4x32x32x1xf16>, %arg9: tensor<15x4x32x32x1xf16>):
        %15 = "arith.constant"() <{value = 0 : index}> : () -> index
        %16 = "arith.constant"() <{value = 32 : index}> : () -> index
        %17:2 = "scf.for"(%15, %16, %2, %arg8, %arg9) ({
        ^bb0(%arg10: index, %arg11: tensor<15x4x32x32x1xf16>, %arg12: tensor<15x4x32x32x1xf16>):
          %18 = "arith.constant"() <{value = 0 : index}> : () -> index
          %19 = "arith.constant"() <{value = 32 : index}> : () -> index
          %20:2 = "scf.for"(%18, %19, %3, %arg11, %arg12) ({
          ^bb0(%arg13: index, %arg14: tensor<15x4x32x32x1xf16>, %arg15: tensor<15x4x32x32x1xf16>):
            %21 = "arith.constant"() <{value = 0 : index}> : () -> index
            %22 = "arith.constant"() <{value = 1 : index}> : () -> index
            %23:2 = "scf.for"(%21, %22, %4, %arg14, %arg15) ({
            ^bb0(%arg16: index, %arg17: tensor<15x4x32x32x1xf16>, %arg18: tensor<15x4x32x32x1xf16>):
              %24 = "tensor.extract_slice"(%arg0, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<1x1x1x1x1xf16>
              %25 = "tensor.extract_slice"(%arg1, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<1x1x1x1x1xf16>
              %26 = "arith.constant"() <{value = 1 : index}> : () -> index
              %27 = "arith.constant"() <{value = 1 : index}> : () -> index
              %28 = "arith.constant"() <{value = 1 : index}> : () -> index
              %29 = "arith.constant"() <{value = 1 : index}> : () -> index
              %30 = "arith.constant"() <{value = 1 : index}> : () -> index
              %31 = "arith.addi"(%arg4, %26) : (index, index) -> index
              %32 = "arith.addi"(%arg7, %27) : (index, index) -> index
              %33 = "arith.addi"(%arg10, %28) : (index, index) -> index
              %34 = "arith.addi"(%arg13, %29) : (index, index) -> index
              %35 = "arith.addi"(%arg16, %30) : (index, index) -> index
              %36 = "arith.constant"() <{value = 1 : index}> : () -> index
              %37 = "arith.constant"() <{value = 1 : index}> : () -> index
              %38 = "arith.constant"() <{value = 1 : index}> : () -> index
              %39 = "arith.constant"() <{value = 1 : index}> : () -> index
              %40 = "arith.constant"() <{value = 1 : index}> : () -> index
              %41 = "arith.addi"(%arg4, %36) : (index, index) -> index
              %42 = "arith.maxsi"(%31, %41) : (index, index) -> index
              %43 = "arith.addi"(%arg7, %37) : (index, index) -> index
              %44 = "arith.maxsi"(%32, %43) : (index, index) -> index
              %45 = "arith.addi"(%arg10, %38) : (index, index) -> index
              %46 = "arith.maxsi"(%33, %45) : (index, index) -> index
              %47 = "arith.addi"(%arg13, %39) : (index, index) -> index
              %48 = "arith.maxsi"(%34, %47) : (index, index) -> index
              %49 = "arith.addi"(%arg16, %40) : (index, index) -> index
              %50 = "arith.maxsi"(%35, %49) : (index, index) -> index
              %51 = "arith.subi"(%42, %arg4) : (index, index) -> index
              %52 = "arith.subi"(%44, %arg7) : (index, index) -> index
              %53 = "arith.subi"(%46, %arg10) : (index, index) -> index
              %54 = "arith.subi"(%48, %arg13) : (index, index) -> index
              %55 = "arith.subi"(%50, %arg16) : (index, index) -> index
              %56 = "affine.apply"(%arg4) <{map = #map1}> : (index) -> index
              %57 = "affine.apply"(%arg7) <{map = #map1}> : (index) -> index
              %58 = "affine.apply"(%arg10) <{map = #map1}> : (index) -> index
              %59 = "tensor.extract_slice"(%arg3, %56, %57, %58, %arg13, %51, %52, %53, %54) <{operand_segment_sizes = array<i32: 1, 4, 4, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<15x4x32x32xf16>, index, index, index, index, index, index, index, index) -> tensor<?x?x?x?xf16>
              %60 = "tensor.expand_shape"(%59) <{reassociation = [[0], [1], [2], [3, 4]]}> : (tensor<?x?x?x?xf16>) -> tensor<?x?x?x?x?xf16>
              %61 = "arith.constant"() <{value = 0 : index}> : () -> index
              %62 = "arith.constant"() <{value = 0 : index}> : () -> index
              %63 = "arith.constant"() <{value = 0 : index}> : () -> index
              %64 = "arith.constant"() <{value = 0 : index}> : () -> index
              %65 = "arith.constant"() <{value = 0 : index}> : () -> index
              %66 = "tensor.extract_slice"(%60) <{operand_segment_sizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<?x?x?x?x?xf16>) -> tensor<1x1x1x1x1xf16>
              %67 = "arith.constant"() <{value = 0 : index}> : () -> index
              %68 = "arith.constant"() <{value = 0 : index}> : () -> index
              %69 = "arith.constant"() <{value = 0 : index}> : () -> index
              %70 = "arith.constant"() <{value = 0 : index}> : () -> index
              %71 = "arith.constant"() <{value = 0 : index}> : () -> index
              %72 = "tensor.extract_slice"(%60) <{operand_segment_sizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<?x?x?x?x?xf16>) -> tensor<1x1x1x1x1xf16>
              %73 = "tensor.extract_slice"(%6, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<1x1x1x1x1xf16>
              %74 = "arith.constant"() <{value = 1 : index}> : () -> index
              %75 = "arith.constant"() <{value = 1 : index}> : () -> index
              %76 = "arith.constant"() <{value = 1 : index}> : () -> index
              %77 = "arith.constant"() <{value = 1 : index}> : () -> index
              %78 = "arith.constant"() <{value = 1 : index}> : () -> index
              %79 = "arith.addi"(%arg4, %74) : (index, index) -> index
              %80 = "arith.addi"(%arg7, %75) : (index, index) -> index
              %81 = "arith.addi"(%arg10, %76) : (index, index) -> index
              %82 = "arith.addi"(%arg13, %77) : (index, index) -> index
              %83 = "arith.addi"(%arg16, %78) : (index, index) -> index
              %84 = "arith.constant"() <{value = 1 : index}> : () -> index
              %85 = "arith.constant"() <{value = 1 : index}> : () -> index
              %86 = "arith.constant"() <{value = 1 : index}> : () -> index
              %87 = "arith.constant"() <{value = 1 : index}> : () -> index
              %88 = "arith.constant"() <{value = 1 : index}> : () -> index
              %89 = "arith.addi"(%arg4, %84) : (index, index) -> index
              %90 = "arith.maxsi"(%79, %89) : (index, index) -> index
              %91 = "arith.addi"(%arg7, %85) : (index, index) -> index
              %92 = "arith.maxsi"(%80, %91) : (index, index) -> index
              %93 = "arith.addi"(%arg10, %86) : (index, index) -> index
              %94 = "arith.maxsi"(%81, %93) : (index, index) -> index
              %95 = "arith.addi"(%arg13, %87) : (index, index) -> index
              %96 = "arith.maxsi"(%82, %95) : (index, index) -> index
              %97 = "arith.addi"(%arg16, %88) : (index, index) -> index
              %98 = "arith.maxsi"(%83, %97) : (index, index) -> index
              %99 = "arith.subi"(%90, %arg4) : (index, index) -> index
              %100 = "arith.subi"(%92, %arg7) : (index, index) -> index
              %101 = "arith.subi"(%94, %arg10) : (index, index) -> index
              %102 = "arith.subi"(%96, %arg13) : (index, index) -> index
              %103 = "arith.subi"(%98, %arg16) : (index, index) -> index
              %104 = "affine.apply"(%arg4) <{map = #map1}> : (index) -> index
              %105 = "affine.apply"(%arg7) <{map = #map1}> : (index) -> index
              %106 = "affine.apply"(%arg10) <{map = #map1}> : (index) -> index
              %107 = "tensor.extract_slice"(%arg2, %104, %105, %106, %arg13, %99, %100, %101, %102) <{operand_segment_sizes = array<i32: 1, 4, 4, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<15x4x32x32xf16>, index, index, index, index, index, index, index, index) -> tensor<?x?x?x?xf16>
              %108 = "tensor.expand_shape"(%107) <{reassociation = [[0], [1], [2], [3, 4]]}> : (tensor<?x?x?x?xf16>) -> tensor<?x?x?x?x?xf16>
              %109 = "arith.constant"() <{value = 0 : index}> : () -> index
              %110 = "arith.constant"() <{value = 0 : index}> : () -> index
              %111 = "arith.constant"() <{value = 0 : index}> : () -> index
              %112 = "arith.constant"() <{value = 0 : index}> : () -> index
              %113 = "arith.constant"() <{value = 0 : index}> : () -> index
              %114 = "tensor.extract_slice"(%108) <{operand_segment_sizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<?x?x?x?x?xf16>) -> tensor<1x1x1x1x1xf16>
              %115 = "arith.constant"() <{value = 0 : index}> : () -> index
              %116 = "arith.constant"() <{value = 0 : index}> : () -> index
              %117 = "arith.constant"() <{value = 0 : index}> : () -> index
              %118 = "arith.constant"() <{value = 0 : index}> : () -> index
              %119 = "arith.constant"() <{value = 0 : index}> : () -> index
              %120 = "tensor.extract_slice"(%108) <{operand_segment_sizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<?x?x?x?x?xf16>) -> tensor<1x1x1x1x1xf16>
              %121 = "tensor.extract_slice"(%5, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<1x1x1x1x1xf16>
              %122 = "tensor.extract_slice"(%6, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<1x1x1x1x1xf16>
              %123 = "tensor.extract_slice"(%5, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<1x1x1x1x1xf16>
              %124 = "tensor.extract_slice"(%7, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<1x1x1x1x1xf16>
              %125 = "tensor.extract_slice"(%7, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<1x1x1x1x1xf16>
              %126:2 = "linalg.generic"(%24, %25, %66, %114, %72, %120, %124, %125) <{indexing_maps = [#map, #map, #map, #map, #map, #map, #map, #map], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operand_segment_sizes = array<i32: 6, 2>}> ({
              ^bb0(%arg19: f16, %arg20: f16, %arg21: f16, %arg22: f16, %arg23: f16, %arg24: f16, %arg25: f16, %arg26: f16):
                %129 = "arith.mulf"(%arg20, %arg21) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
                %130 = "arith.mulf"(%arg19, %arg22) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
                %131 = "arith.subf"(%130, %129) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
                %132 = "arith.mulf"(%arg19, %arg23) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
                %133 = "arith.mulf"(%arg20, %arg24) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
                %134 = "arith.addf"(%133, %132) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
                "linalg.yield"(%131, %134) : (f16, f16) -> ()
              }) {__byteir_gpu_tile_elementwise_0} : (tensor<1x1x1x1x1xf16>, tensor<1x1x1x1x1xf16>, tensor<1x1x1x1x1xf16>, tensor<1x1x1x1x1xf16>, tensor<1x1x1x1x1xf16>, tensor<1x1x1x1x1xf16>, tensor<1x1x1x1x1xf16>, tensor<1x1x1x1x1xf16>) -> (tensor<1x1x1x1x1xf16>, tensor<1x1x1x1x1xf16>)
              %127 = "tensor.insert_slice"(%126#0, %arg17, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<1x1x1x1x1xf16>, tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<15x4x32x32x1xf16>
              %128 = "tensor.insert_slice"(%126#1, %arg18, %arg4, %arg7, %arg10, %arg13, %arg16) <{operand_segment_sizes = array<i32: 1, 1, 5, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, 1, 1, 1, 1>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (tensor<1x1x1x1x1xf16>, tensor<15x4x32x32x1xf16>, index, index, index, index, index) -> tensor<15x4x32x32x1xf16>
              "scf.yield"(%127, %128) : (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> ()
            }) : (index, index, index, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>)
            "scf.yield"(%23#0, %23#1) : (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> ()
          }) : (index, index, index, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>)
          "scf.yield"(%20#0, %20#1) : (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> ()
        }) : (index, index, index, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>)
        "scf.yield"(%17#0, %17#1) : (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> ()
      }) : (index, index, index, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>)
      "scf.yield"(%14#0, %14#1) : (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> ()
    }) : (index, index, index, tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>)
    "func.return"(%11#0, %11#1) : (tensor<15x4x32x32x1xf16>, tensor<15x4x32x32x1xf16>) -> ()
  }) {__byteir_elementwise_fusion__} : () -> ()
}) : () -> ()

cc: @liwenchangbdbz @yaochengji

[MemrefToByre] Failed to convert memref.subview

I'm trying a simple case that concatenates 2 integers, and compilation fails. The bug seems to be inside MemrefCopyToLinalgPass: one of the two memref.copy ops is not converted to Linalg, so the later MemrefToByre conversion fails on the corresponding memref.subview. cc: @liwenchangbdbz @xrzhang111

Input mhlo:

func.func @forward(%arg0: tensor<i64>, %arg1: tensor<i64>) -> (tensor<2xi64>) {
  %0 = mhlo.reshape %arg0 : (tensor<i64>) -> tensor<1xi64>
  %1 = mhlo.reshape %arg1 : (tensor<i64>) -> tensor<1xi64>
  %2 = "mhlo.concatenate"(%0, %1) {dimension = 0 : i64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
  return %2 : tensor<2xi64>
}

IR before MemrefCopyToLinalgPass:

module {
  func.func @forward(%arg0: memref<i64>, %arg1: memref<i64>) -> memref<2xi64> attributes {__placeholder__byre.entry_point} {
    %expand_shape = memref.expand_shape %arg0 [] : memref<i64> into memref<1xi64>
    %expand_shape_0 = memref.expand_shape %arg1 [] : memref<i64> into memref<1xi64>
    %alloc = memref.alloc() : memref<2xi64>
    %subview = memref.subview %alloc[0] [1] [1] : memref<2xi64> to memref<1xi64, strided<[1]>>
    memref.copy %expand_shape, %subview : memref<1xi64> to memref<1xi64, strided<[1]>>
    %subview_1 = memref.subview %alloc[1] [1] [1] : memref<2xi64> to memref<1xi64, strided<[1], offset: 1>>
    memref.copy %expand_shape_0, %subview_1 : memref<1xi64> to memref<1xi64, strided<[1], offset: 1>>
    return %alloc : memref<2xi64>
  }
}

IR After MemrefCopyToLinalgPass:

module {
  func.func private @memref_copy_kernel(%arg0: memref<i64>, %arg1: memref<2xi64>) attributes {__byteir_elementwise_fusion__} {
    %subview = memref.subview %arg1[1] [1] [1] : memref<2xi64> to memref<1xi64, strided<[1], offset: 1>>
    %expand_shape = memref.expand_shape %arg0 [] : memref<i64> into memref<1xi64>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%expand_shape : memref<1xi64>) outs(%subview : memref<1xi64, strided<[1], offset: 1>>) {
    ^bb0(%in: i64, %out: i64):
      linalg.yield %in : i64
    }
    return
  }
  func.func @forward(%arg0: memref<i64>, %arg1: memref<i64>) -> memref<2xi64> attributes {__placeholder__byre.entry_point} {
    %expand_shape = memref.expand_shape %arg0 [] : memref<i64> into memref<1xi64>
    %alloc = memref.alloc() : memref<2xi64>
    %subview = memref.subview %alloc[0] [1] [1] : memref<2xi64> to memref<1xi64, strided<[1]>>
    memref.copy %expand_shape, %subview : memref<1xi64> to memref<1xi64, strided<[1]>>
    call @memref_copy_kernel(%arg1, %alloc) {num_readonly_operand = 1 : index} : (memref<i64>, memref<2xi64>) -> ()
    return %alloc : memref<2xi64>
  }
}
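
For comparison, here is a hypothetical sketch (not output produced by the pass) of the remaining memref.copy rewritten in the same linalg.generic form as the outlined kernel above; the function name @memref_copy_kernel_0 and the attribute are assumptions for illustration only:

#map = affine_map<(d0) -> (d0)>
func.func private @memref_copy_kernel_0(%arg0: memref<i64>, %arg1: memref<2xi64>) attributes {__byteir_elementwise_fusion__} {
  // Copy the scalar operand into the first element of the result buffer.
  %subview = memref.subview %arg1[0] [1] [1] : memref<2xi64> to memref<1xi64, strided<[1]>>
  %expand_shape = memref.expand_shape %arg0 [] : memref<i64> into memref<1xi64>
  linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%expand_shape : memref<1xi64>) outs(%subview : memref<1xi64, strided<[1]>>) {
  ^bb0(%in: i64, %out: i64):
    linalg.yield %in : i64
  }
  return
}

This only illustrates the expected shape of the conversion; whether the downstream MemrefToByre conversion can then handle the result is exactly what this issue is about.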

[Compiler GPU] LLVM ERROR: operation destroyed but still has uses

I'm getting LLVM ERROR: operation destroyed but still has uses when running the gpu-opt pipeline. The failing pass is ConvertFuncToGPUPass.

error msg:

../test.mlir:374:13: error: 'scf.for' op operation destroyed but still has uses
    %0:28 = scf.for %arg75 = %c0 to %c256 step %c1 iter_args(%arg76 = %alloc_0, %arg77 = %alloc_3, %arg78 = %alloc_4, %arg79 = %alloc_5, %arg80 = %alloc_6, %arg81 = %alloc_7, %arg82 = %alloc_8, %arg83 = %alloc_9, %arg84 = %alloc_10, %arg85 = %alloc_11, %arg86 = %alloc_12, %arg87 = %alloc_13, %arg88 = %alloc_14, %arg89 = %alloc_15, %arg90 = %alloc_16, %arg91 = %alloc_17, %arg92 = %alloc_18, %arg93 = %alloc_19, %arg94 = %alloc_20, %arg95 = %alloc_21, %arg96 = %alloc_22, %arg97 = %alloc_23, %arg98 = %alloc_24, %arg99 = %alloc_25, %arg100 = %alloc_26, %arg101 = %alloc_27, %arg102 = %alloc_28, %arg103 = %alloc_29) -> (memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>) {
            ^
../test.mlir:374:13: note: see current operation: 
%0:28 = "scf.for"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) ({
^bb0(%arg0: index, %arg1: memref<1x256x1024xf32>, %arg2: memref<1x256x1024xf32>, %arg3: memref<1x256x1024xf32>, %arg4: memref<1x256x1024xf32>, %arg5: memref<1x256x1024xf32>, %arg6: memref<1x256x1024xf32>, %arg7: memref<1x256x1024xf32>, %arg8: memref<1x256x1024xf32>, %arg9: memref<1x256x1024xf32>, %arg10: memref<1x256x1024xf32>, %arg11: memref<1x256x1024xf32>, %arg12: memref<1x256x1024xf32>, %arg13: memref<1x256x1024xf32>, %arg14: memref<1x256x1024xf32>, %arg15: memref<1x256x1024xf32>, %arg16: memref<1x256x1024xf32>, %arg17: memref<1x256x1024xf32>, %arg18: memref<1x256x1024xf32>, %arg19: memref<1x256x1024xf32>, %arg20: memref<1x256x1024xf32>, %arg21: memref<1x256x1024xf32>, %arg22: memref<1x256x1024xf32>, %arg23: memref<1x256x1024xf32>, %arg24: memref<1x256x1024xf32>, %arg25: memref<1x256x1024xf32>, %arg26: memref<1x256x1024xf32>, %arg27: memref<1x256x1024xf32>, %arg28: memref<1x256x1024xf32>):
...

LLVM ERROR: operation destroyed but still has uses

cmd to reproduce:
byteir-opt --convert-func-to-gpu input.mlir

Input mlir:

func.func private @Unknown0(%arg0: memref<1x256x1024xf32>, %arg1: memref<256x1024xf16>, %arg2: memref<256x1024xf16>, %arg3: memref<256x1024xf16>, %arg4: memref<256x1024xf16>, %arg5: memref<256x1024xf16>, %arg6: memref<256x1024xf16>, %arg7: memref<256x1024xf16>, %arg8: memref<256x1024xf16>, %arg9: memref<256x1024xf16>, %arg10: memref<256x1024xf16>, %arg11: memref<256x1024xf16>, %arg12: memref<256x1024xf16>, %arg13: memref<256x1024xf16>, %arg14: memref<256x1024xf16>, %arg15: memref<256x1024xf16>, %arg16: memref<256x1024xf16>, %arg17: memref<256x1024xf16>, %arg18: memref<256x1024xf16>, %arg19: memref<256x1024xf16>, %arg20: memref<256x1024xf16>, %arg21: memref<256x1024xf16>, %arg22: memref<256x1024xf16>, %arg23: memref<256x1024xf16>, %arg24: memref<256x1024xf16>, %arg25: memref<256x1024xf16>, %arg26: memref<1x256x1xf32>, %arg27: memref<1x256x1xf32>, %arg28: memref<1024xf32>, %arg29: memref<256x1xf32>, %arg30: memref<256x1xf32>, %arg31: memref<1x256x1xf32>, %arg32: memref<1x256x1xf32>, %arg33: memref<256x1xf32>, %arg34: memref<256x1xf32>, %arg35: memref<1x256x1xf32>, %arg36: memref<1x256x1xf32>, %arg37: memref<256x1xf32>, %arg38: memref<256x1xf32>, %arg39: memref<1x256x1xf32>, %arg40: memref<1x256x1xf32>, %arg41: memref<256x1xf32>, %arg42: memref<256x1xf32>, %arg43: memref<1x256x1xf32>, %arg44: memref<1x256x1xf32>, %arg45: memref<256x1xf32>, %arg46: memref<256x1xf32>, %arg47: memref<1x256x1xf32>, %arg48: memref<1x256x1xf32>, %arg49: memref<256x1xf32>, %arg50: memref<256x1xf32>, %arg51: memref<1x256x1xf32>, %arg52: memref<1x256x1xf32>, %arg53: memref<256x1xf32>, %arg54: memref<256x1xf32>, %arg55: memref<1x256x1xf32>, %arg56: memref<1x256x1xf32>, %arg57: memref<256x1xf32>, %arg58: memref<256x1xf32>, %arg59: memref<1x256x1xf32>, %arg60: memref<1x256x1xf32>, %arg61: memref<256x1xf32>, %arg62: memref<256x1xf32>, %arg63: memref<1x256x1xf32>, %arg64: memref<1x256x1xf32>, %arg65: memref<256x1xf32>, %arg66: memref<256x1xf32>, %arg67: memref<1x256x1xf32>, %arg68: memref<1x256x1xf32>, %arg69: memref<256x1xf32>, %arg70: memref<256x1xf32>, %arg71: memref<1x256x1xf32>, %arg72: memref<1x256x1xf32>, %arg73: memref<256x1xf32>, %arg74: memref<256x1xf32>) -> (memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>) attributes {__byteir_elementwise_fusion__, __byteir_to_gpu__} {
  %cst = arith.constant 1.024000e+03 : f32
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c256 = arith.constant 256 : index
  %c1024 = arith.constant 1024 : index
  %c262144 = arith.constant 262144 : index
  %alloc = memref.alloc() : memref<1x256x1024xf32>
  %alloca = memref.alloca() : memref<1x256x1024xf32>
  %alloc_0 = memref.alloc() : memref<256x1024xf32>
  memref.copy %alloca, %alloc : memref<1x256x1024xf32> to memref<1x256x1024xf32>
  %alloc_1 = memref.alloc() : memref<256x1024xf32>
  %alloca_2 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_3 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_4 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_5 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_6 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_7 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_8 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_9 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_10 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_11 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_12 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_13 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_14 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_15 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_16 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_17 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_18 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_19 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_20 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_21 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_22 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_23 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_24 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_25 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_26 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_27 = memref.alloca() : memref<1x256x1024xf32>
  %alloca_28 = memref.alloca() : memref<1x256x1024xf32>
  %0:28 = scf.for %arg75 = %c0 to %c256 step %c1 iter_args(%arg76 = %alloca, %arg77 = %alloca_2, %arg78 = %alloca_3, %arg79 = %alloca_4, %arg80 = %alloca_5, %arg81 = %alloca_6, %arg82 = %alloca_7, %arg83 = %alloca_8, %arg84 = %alloca_9, %arg85 = %alloca_10, %arg86 = %alloca_11, %arg87 = %alloca_12, %arg88 = %alloca_13, %arg89 = %alloca_14, %arg90 = %alloca_15, %arg91 = %alloca_16, %arg92 = %alloca_17, %arg93 = %alloca_18, %arg94 = %alloca_19, %arg95 = %alloca_20, %arg96 = %alloca_21, %arg97 = %alloca_22, %arg98 = %alloca_23, %arg99 = %alloca_24, %arg100 = %alloca_25, %arg101 = %alloca_26, %arg102 = %alloca_27, %arg103 = %alloca_28) -> (memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>) {
    %alloca_39 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_40 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_41 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_42 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_43 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_44 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_45 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_46 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_47 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_48 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_49 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_50 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_51 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_52 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_53 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_54 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_55 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_56 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_57 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_58 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_59 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_60 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_61 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_62 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_63 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_64 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_65 = memref.alloca() : memref<1x256x1024xf32>
    %alloca_66 = memref.alloca() : memref<1x256x1024xf32>
    %1:28 = scf.for %arg104 = %c0 to %c1024 step %c1 iter_args(%arg105 = %alloca_39, %arg106 = %alloca_40, %arg107 = %alloca_41, %arg108 = %alloca_42, %arg109 = %alloca_43, %arg110 = %alloca_44, %arg111 = %alloca_45, %arg112 = %alloca_46, %arg113 = %alloca_47, %arg114 = %alloca_48, %arg115 = %alloca_49, %arg116 = %alloca_50, %arg117 = %alloca_51, %arg118 = %alloca_52, %arg119 = %alloca_53, %arg120 = %alloca_54, %arg121 = %alloca_55, %arg122 = %alloca_56, %arg123 = %alloca_57, %arg124 = %alloca_58, %arg125 = %alloca_59, %arg126 = %alloca_60, %arg127 = %alloca_61, %arg128 = %alloca_62, %arg129 = %alloca_63, %arg130 = %alloca_64, %arg131 = %alloca_65, %arg132 = %alloca_66) -> (memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>) {
      %2 = arith.remsi %arg75, %c256 : index
      %3 = arith.cmpi slt, %2, %c0 : index
      %4 = arith.addi %2, %c256 : index
      %5 = arith.select %3, %4, %2 : index
      %6 = arith.cmpi slt, %arg104, %c0 : index
      %7 = arith.addi %arg104, %c1024 : index
      %8 = arith.select %6, %7, %arg104 : index
      %subview = memref.subview %alloc[0, %5, %8] [1, 1, 1] [1, 1, 1] : memref<1x256x1024xf32> to memref<f32, strided<[], offset: ?>>
      %alloca_67 = memref.alloca() : memref<f32>
      %alloca_68 = memref.alloca() : memref<f32>
      %alloca_69 = memref.alloca() : memref<f32>
      %alloca_70 = memref.alloca() : memref<f32>
      %alloca_71 = memref.alloca() : memref<f32>
      %alloca_72 = memref.alloca() : memref<f32>
      %alloca_73 = memref.alloca() : memref<f32>
      %alloca_74 = memref.alloca() : memref<f32>
      %alloca_75 = memref.alloca() : memref<f32>
      %alloca_76 = memref.alloca() : memref<f32>
      %alloca_77 = memref.alloca() : memref<f32>
      %alloca_78 = memref.alloca() : memref<f32>
      %alloca_79 = memref.alloca() : memref<f32>
      %alloca_80 = memref.alloca() : memref<f32>
      %alloca_81 = memref.alloca() : memref<f32>
      %alloca_82 = memref.alloca() : memref<f32>
      %alloca_83 = memref.alloca() : memref<f32>
      %alloca_84 = memref.alloca() : memref<f32>
      %alloca_85 = memref.alloca() : memref<f32>
      %alloca_86 = memref.alloca() : memref<f32>
      %alloca_87 = memref.alloca() : memref<f32>
      %alloca_88 = memref.alloca() : memref<f32>
      %alloca_89 = memref.alloca() : memref<f32>
      %alloca_90 = memref.alloca() : memref<f32>
      %alloca_91 = memref.alloca() : memref<f32>
      %alloca_92 = memref.alloca() : memref<f32>
      %alloca_93 = memref.alloca() : memref<f32>
      %9 = memref.load %arg25[%5, %8] : memref<256x1024xf16>
      %10 = memref.load %arg28[%8] : memref<1024xf32>
      %11 = memref.load %arg0[%c0, %5, %8] : memref<1x256x1024xf32>
      %12 = memref.load %arg1[%5, %8] : memref<256x1024xf16>
      %13 = memref.load %arg2[%5, %8] : memref<256x1024xf16>
      %14 = memref.load %arg3[%5, %8] : memref<256x1024xf16>
      %15 = memref.load %arg4[%5, %8] : memref<256x1024xf16>
      %16 = memref.load %arg5[%5, %8] : memref<256x1024xf16>
      %17 = memref.load %arg6[%5, %8] : memref<256x1024xf16>
      %18 = memref.load %arg7[%5, %8] : memref<256x1024xf16>
      %19 = memref.load %arg8[%5, %8] : memref<256x1024xf16>
      %20 = memref.load %arg9[%5, %8] : memref<256x1024xf16>
      %21 = memref.load %arg10[%5, %8] : memref<256x1024xf16>
      %22 = memref.load %arg11[%5, %8] : memref<256x1024xf16>
      %23 = memref.load %arg12[%5, %8] : memref<256x1024xf16>
      %24 = memref.load %arg13[%5, %8] : memref<256x1024xf16>
      %25 = memref.load %arg14[%5, %8] : memref<256x1024xf16>
      %26 = memref.load %arg15[%5, %8] : memref<256x1024xf16>
      %27 = memref.load %arg16[%5, %8] : memref<256x1024xf16>
      %28 = memref.load %arg17[%5, %8] : memref<256x1024xf16>
      %29 = memref.load %arg18[%5, %8] : memref<256x1024xf16>
      %30 = memref.load %arg19[%5, %8] : memref<256x1024xf16>
      %31 = memref.load %arg20[%5, %8] : memref<256x1024xf16>
      %32 = memref.load %arg21[%5, %8] : memref<256x1024xf16>
      %33 = memref.load %arg22[%5, %8] : memref<256x1024xf16>
      %34 = memref.load %arg23[%5, %8] : memref<256x1024xf16>
      %35 = memref.load %arg24[%5, %8] : memref<256x1024xf16>
      %36 = memref.load %arg26[%c0, %5, %c0] : memref<1x256x1xf32>
      %37 = memref.load %arg27[%c0, %5, %c0] : memref<1x256x1xf32>
      %38 = memref.load %arg31[%c0, %5, %c0] : memref<1x256x1xf32>
      %39 = memref.load %arg32[%c0, %5, %c0] : memref<1x256x1xf32>
      %40 = memref.load %arg35[%c0, %5, %c0] : memref<1x256x1xf32>
      %41 = memref.load %arg36[%c0, %5, %c0] : memref<1x256x1xf32>
      %42 = memref.load %arg39[%c0, %5, %c0] : memref<1x256x1xf32>
      %43 = memref.load %arg40[%c0, %5, %c0] : memref<1x256x1xf32>
      %44 = memref.load %arg43[%c0, %5, %c0] : memref<1x256x1xf32>
      %45 = memref.load %arg44[%c0, %5, %c0] : memref<1x256x1xf32>
      %46 = memref.load %arg47[%c0, %5, %c0] : memref<1x256x1xf32>
      %47 = memref.load %arg48[%c0, %5, %c0] : memref<1x256x1xf32>
      %48 = memref.load %arg51[%c0, %5, %c0] : memref<1x256x1xf32>
      %49 = memref.load %arg52[%c0, %5, %c0] : memref<1x256x1xf32>
      %50 = memref.load %arg55[%c0, %5, %c0] : memref<1x256x1xf32>
      %51 = memref.load %arg56[%c0, %5, %c0] : memref<1x256x1xf32>
      %52 = memref.load %arg59[%c0, %5, %c0] : memref<1x256x1xf32>
      %53 = memref.load %arg60[%c0, %5, %c0] : memref<1x256x1xf32>
      %54 = memref.load %arg63[%c0, %5, %c0] : memref<1x256x1xf32>
      %55 = memref.load %arg64[%c0, %5, %c0] : memref<1x256x1xf32>
      %56 = memref.load %arg67[%c0, %5, %c0] : memref<1x256x1xf32>
      %57 = memref.load %arg68[%c0, %5, %c0] : memref<1x256x1xf32>
      %58 = memref.load %arg71[%c0, %5, %c0] : memref<1x256x1xf32>
      %59 = memref.load %arg72[%c0, %5, %c0] : memref<1x256x1xf32>
      %60 = memref.load %arg73[%arg75, %c0] : memref<256x1xf32>
      %61 = memref.load %arg74[%arg75, %c0] : memref<256x1xf32>
      %62 = arith.extf %34 : f16 to f32
      %63 = arith.extf %32 : f16 to f32
      %64 = arith.extf %30 : f16 to f32
      %65 = arith.extf %28 : f16 to f32
      %66 = arith.extf %26 : f16 to f32
      %67 = arith.extf %24 : f16 to f32
      %68 = arith.extf %22 : f16 to f32
      %69 = arith.extf %20 : f16 to f32
      %70 = arith.extf %18 : f16 to f32
      %71 = arith.extf %16 : f16 to f32
      %72 = arith.extf %14 : f16 to f32
      %73 = arith.extf %12 : f16 to f32
      %74 = arith.addf %11, %73 : f32
      %75 = arith.extf %13 : f16 to f32
      %76 = arith.addf %74, %75 : f32
      %77 = arith.addf %76, %72 : f32
      %78 = arith.extf %15 : f16 to f32
      %79 = arith.addf %77, %78 : f32
      %80 = arith.addf %79, %71 : f32
      %81 = arith.extf %17 : f16 to f32
      %82 = arith.addf %80, %81 : f32
      %83 = arith.addf %82, %70 : f32
      %84 = arith.extf %19 : f16 to f32
      %85 = arith.addf %83, %84 : f32
      %86 = arith.addf %85, %69 : f32
      %87 = arith.extf %21 : f16 to f32
      %88 = arith.addf %86, %87 : f32
      %89 = arith.addf %88, %68 : f32
      %90 = arith.extf %23 : f16 to f32
      %91 = arith.addf %89, %90 : f32
      %92 = arith.addf %91, %67 : f32
      %93 = arith.extf %25 : f16 to f32
      %94 = arith.addf %92, %93 : f32
      %95 = arith.addf %94, %66 : f32
      %96 = arith.extf %27 : f16 to f32
      %97 = arith.addf %95, %96 : f32
      %98 = arith.addf %97, %65 : f32
      %99 = arith.extf %29 : f16 to f32
      %100 = arith.addf %98, %99 : f32
      %101 = arith.addf %100, %64 : f32
      %102 = arith.extf %31 : f16 to f32
      %103 = arith.addf %101, %102 : f32
      %104 = arith.addf %103, %63 : f32
      %105 = arith.extf %33 : f16 to f32
      %106 = arith.addf %104, %105 : f32
      %107 = arith.addf %106, %62 : f32
      %108 = arith.extf %35 : f16 to f32
      %109 = arith.addf %107, %108 : f32
      %110 = arith.subf %109, %36 : f32
      %111 = arith.mulf %110, %37 : f32
      %112 = arith.extf %9 : f16 to f32
      %113 = arith.mulf %112, %10 : f32
      %114 = arith.mulf %113, %cst : f32
      %115 = arith.mulf %113, %111 : f32
      %116 = arith.mulf %112, %111 : f32
      %117 = arith.subf %106, %38 : f32
      %118 = arith.mulf %117, %39 : f32
      %119 = arith.subf %103, %40 : f32
      %120 = arith.mulf %119, %41 : f32
      %121 = arith.subf %100, %42 : f32
      %122 = arith.mulf %121, %43 : f32
      %123 = arith.subf %97, %44 : f32
      %124 = arith.mulf %123, %45 : f32
      %125 = arith.subf %94, %46 : f32
      %126 = arith.mulf %125, %47 : f32
      %127 = arith.subf %91, %48 : f32
      %128 = arith.mulf %127, %49 : f32
      %129 = arith.subf %88, %50 : f32
      %130 = arith.mulf %129, %51 : f32
      %131 = arith.subf %85, %52 : f32
      %132 = arith.mulf %131, %53 : f32
      %133 = arith.subf %82, %54 : f32
      %134 = arith.mulf %133, %55 : f32
      %135 = arith.subf %79, %56 : f32
      %136 = arith.mulf %135, %57 : f32
      %137 = arith.subf %76, %58 : f32
      %138 = arith.mulf %137, %59 : f32
      %139 = arith.subf %74, %60 : f32
      %140 = arith.mulf %139, %61 : f32
      memref.store %111, %alloc[%c0, %5, %8] : memref<1x256x1024xf32>
      memref.store %77, %alloca_67[] : memref<f32>
      memref.store %80, %alloca_68[] : memref<f32>
      memref.store %83, %alloca_69[] : memref<f32>
      memref.store %86, %alloca_70[] : memref<f32>
      memref.store %89, %alloca_71[] : memref<f32>
      memref.store %92, %alloca_72[] : memref<f32>
      memref.store %95, %alloca_73[] : memref<f32>
      memref.store %98, %alloca_74[] : memref<f32>
      memref.store %101, %alloca_75[] : memref<f32>
      memref.store %104, %alloca_76[] : memref<f32>
      memref.store %107, %alloca_77[] : memref<f32>
      memref.store %112, %alloca_78[] : memref<f32>
      memref.store %113, %alloca_79[] : memref<f32>
      memref.store %114, %alloca_80[] : memref<f32>
      memref.store %115, %alloca_81[] : memref<f32>
      memref.store %116, %alloca_82[] : memref<f32>
      memref.store %118, %alloca_83[] : memref<f32>
      memref.store %120, %alloca_84[] : memref<f32>
      memref.store %122, %alloca_85[] : memref<f32>
      memref.store %124, %alloca_86[] : memref<f32>
      memref.store %126, %alloca_87[] : memref<f32>
      memref.store %128, %alloca_88[] : memref<f32>
      memref.store %130, %alloca_89[] : memref<f32>
      memref.store %132, %alloca_90[] : memref<f32>
      memref.store %134, %alloca_91[] : memref<f32>
      memref.store %136, %alloca_92[] : memref<f32>
      memref.store %138, %alloca_93[] : memref<f32>
      memref.store %140, %alloc_1[%arg75, %arg104] : memref<256x1024xf32>
      memref.copy %alloca_67, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_68, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_69, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_70, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_71, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_72, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_73, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_74, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_75, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_76, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_77, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_78, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_79, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_80, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_81, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_82, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_83, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_84, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_85, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_86, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_87, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_88, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_89, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_90, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_91, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_92, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloca_93, %subview : memref<f32> to memref<f32, strided<[], offset: ?>>
      scf.yield %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc, %alloc : memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>
    }
    scf.yield %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7, %1#8, %1#9, %1#10, %1#11, %1#12, %1#13, %1#14, %1#15, %1#16, %1#17, %1#18, %1#19, %1#20, %1#21, %1#22, %1#23, %1#24, %1#25, %1#26, %1#27 : memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_29 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#1[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg69[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg70[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_29[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_30 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#2[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg65[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg66[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_30[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_31 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#3[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg61[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg62[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_31[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_32 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#4[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg57[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg58[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_32[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_33 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#5[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg53[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg54[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_33[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_34 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#6[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg49[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg50[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_34[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_35 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#7[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg45[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg46[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_35[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_36 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#8[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg41[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg42[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_36[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_37 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#9[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg37[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg38[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_37[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  %alloc_38 = memref.alloc() : memref<256x1024xf32>
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#10[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg33[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg34[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_38[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  scf.for %arg75 = %c0 to %c262144 step %c1 {
    %1 = arith.remsi %arg75, %c1024 : index
    %2 = arith.divsi %arg75, %c1024 : index
    %3 = arith.remsi %2, %c256 : index
    %4 = arith.remsi %1, %c1024 : index
    %5 = memref.load %0#11[%c0, %3, %4] : memref<1x256x1024xf32>
    %6 = memref.load %arg29[%2, %c0] : memref<256x1xf32>
    %7 = memref.load %arg30[%2, %c0] : memref<256x1xf32>
    %8 = arith.subf %5, %6 : f32
    %9 = arith.mulf %8, %7 : f32
    memref.store %9, %alloc_0[%2, %1] : memref<256x1024xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  return %0#12, %0#0, %0#13, %0#14, %0#15, %0#16, %alloc_0, %0#17, %alloc_38, %0#18, %alloc_37, %0#19, %alloc_36, %0#20, %alloc_35, %0#21, %alloc_34, %0#22, %alloc_33, %0#23, %alloc_32, %0#24, %alloc_31, %0#25, %alloc_30, %0#26, %alloc_29, %0#27, %alloc_1 : memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>, memref<1x256x1024xf32>, memref<256x1024xf32>
}
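
For reference, a much smaller hypothetical input of the same shape, i.e. an scf.for that carries memref.alloca buffers as iter_args inside a function tagged for GPU conversion. This is an untested sketch meant only to illustrate the pattern; @repro, the 4-element shape, and the exact attribute set are made up:

func.func private @repro(%arg0: memref<4xf32>) -> memref<4xf32> attributes {__byteir_elementwise_fusion__, __byteir_to_gpu__} {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %alloca = memref.alloca() : memref<4xf32>
  // The loop yields a memref iter_arg, mirroring the structure of the large input above.
  %0 = scf.for %i = %c0 to %c4 step %c1 iter_args(%buf = %alloca) -> (memref<4xf32>) {
    %v = memref.load %arg0[%i] : memref<4xf32>
    memref.store %v, %buf[%i] : memref<4xf32>
    scf.yield %buf : memref<4xf32>
  } {__byteir_coarsen_simt__, __byteir_loop_to_simt__ = "linear_id.x"}
  return %0 : memref<4xf32>
}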

cc: @liwenchangbdbz @yaochengji

[Compiler] invalid PTX

Hitting CUDA_ERROR_INVALID_PTX in one of the models.

Resulting ptx:

	// .globl	Unknown576
.visible .entry Unknown576(
	.param .u64 Unknown576_param_0,
	.param .u64 Unknown576_param_1,
	.param .u64 Unknown576_param_2,
	.param .u64 Unknown576_param_3,
	.param .u64 Unknown576_param_4,
	.param .u64 Unknown576_param_5,
	.param .u64 Unknown576_param_6,
	.param .u64 Unknown576_param_7,
	.param .u64 Unknown576_param_8,
	.param .u64 Unknown576_param_9,
	.param .u64 Unknown576_param_10,
	.param .u64 Unknown576_param_11,
	.param .u64 Unknown576_param_12,
	.param .u64 Unknown576_param_13,
	.param .u64 Unknown576_param_14,
	.param .u64 Unknown576_param_15,
	.param .u64 Unknown576_param_16,
	.param .u64 Unknown576_param_17
)
{
	.reg .pred 	%p<3>;
	.reg .b16 	%h<2>;
	.reg .b32 	%r<5>;
	.reg .f32 	%f<2>;
	.reg .b64 	%rd<24>;

	mov.u32 	%r1, %ctaid.x;
	mov.u32 	%r2, %ntid.x;
	mov.u32 	%r3, %tid.x;
	cvt.s64.s32 	%rd17, %r3;
	mul.wide.s32 	%rd18, %r2, %r1;
	add.s64 	%rd23, %rd18, %rd17;
	setp.gt.s64 	%p1, %rd23, 12865791;
	@%p1 bra 	$L__BB259_3;
	ld.param.u64 	%rd15, [Unknown576_param_10];
	cvta.to.global.u64 	%rd1, %rd15;
	ld.param.u64 	%rd16, [Unknown576_param_1];
	cvta.to.global.u64 	%rd2, %rd16;
	mov.u32 	%r4, %nctaid.x;
	mul.wide.s32 	%rd4, %r2, %r4;
	shl.b64 	%rd19, %rd23, 1;
	add.s64 	%rd22, %rd2, %rd19;
	shl.b64 	%rd6, %rd4, 1;
	shl.b64 	%rd20, %rd23, 2;
	add.s64 	%rd21, %rd1, %rd20;
	shl.b64 	%rd8, %rd4, 2;
$L__BB259_2:
	ld.global.nc.b16 	%h1, [%rd22];
	cvt.f32.f16 	%f1, %h1;
	st.global.f32 	[%rd21], %f1;
	add.s64 	%rd23, %rd23, %rd4;
	add.s64 	%rd22, %rd22, %rd6;
	add.s64 	%rd21, %rd21, %rd8;
	setp.lt.s64 	%p2, %rd23, 12865792;
	@%p2 bra 	$L__BB259_2;
$L__BB259_3:
	ret;

}

mhlo:

  func.func private @Unknown576(%arg0: tensor<1x256x50257xf16>) -> (tensor<256x50257xf16>, tensor<256x50257xf32>) attributes {__byteir_elementwise_fusion__} {
    %0 = mhlo.reshape %arg0 : (tensor<1x256x50257xf16>) -> tensor<256x50257xf16>
    %1 = mhlo.convert %0 : (tensor<256x50257xf16>) -> tensor<256x50257xf32>
    return %0, %1 : tensor<256x50257xf16>, tensor<256x50257xf32>
  }

Note that the unit test works fine but generates different PTX:

	// .globl	Unknown0

.visible .entry Unknown0(
	.param .u64 Unknown0_param_0,
	.param .u64 Unknown0_param_1
)
{
	.reg .pred 	%p<3>;
	.reg .b16 	%rs<2>;
	.reg .b32 	%r<5>;
	.reg .f32 	%f<2>;
	.reg .b64 	%rd<24>;

	mov.u32 	%r1, %ctaid.x;
	mov.u32 	%r2, %ntid.x;
	mov.u32 	%r3, %tid.x;
	cvt.s64.s32 	%rd17, %r3;
	mul.wide.s32 	%rd18, %r2, %r1;
	add.s64 	%rd23, %rd18, %rd17;
	setp.gt.s64 	%p1, %rd23, 12865791;
	@%p1 bra 	$L__BB0_3;
	ld.param.u64 	%rd15, [Unknown0_param_0];
	ld.param.u64 	%rd16, [Unknown0_param_1];
	cvta.to.global.u64 	%rd1, %rd16;
	cvta.to.global.u64 	%rd2, %rd15;
	mov.u32 	%r4, %nctaid.x;
	mul.wide.s32 	%rd4, %r2, %r4;
	shl.b64 	%rd19, %rd23, 1;
	add.s64 	%rd22, %rd2, %rd19;
	shl.b64 	%rd6, %rd4, 1;
	shl.b64 	%rd20, %rd23, 2;
	add.s64 	%rd21, %rd1, %rd20;
	shl.b64 	%rd8, %rd4, 2;
$L__BB0_2:
	ld.global.nc.u16 	%rs1, [%rd22];
	cvt.f32.f16 	%f1, %rs1;
	st.global.f32 	[%rd21], %f1;
	add.s64 	%rd23, %rd23, %rd4;
	add.s64 	%rd22, %rd22, %rd6;
	add.s64 	%rd21, %rd21, %rd8;
	setp.lt.s64 	%p2, %rd23, 12865792;
	@%p2 bra 	$L__BB0_2;
$L__BB0_3:
	ret;

}
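
For orientation, both PTX dumps implement the same element-wise f16-to-f32 convert described by the mhlo above (load an f16, cvt.f32.f16, store an f32). A hypothetical memref-level sketch of just that convert, with the function name and buffer arguments assumed for illustration:

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func private @convert_sketch(%arg0: memref<256x50257xf16>, %arg1: memref<256x50257xf32>) {
  // Element-wise widen f16 -> f32, matching the body of the PTX loops above.
  linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : memref<256x50257xf16>) outs(%arg1 : memref<256x50257xf32>) {
  ^bb0(%in: f16, %out: f32):
    %0 = arith.extf %in : f16 to f32
    linalg.yield %0 : f32
  }
  return
}

The two dumps differ mainly in how the f16 load is materialized (ld.global.nc.b16 into an %h register vs. ld.global.nc.u16 into an %rs register) and in the number of kernel parameters.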

Contact me for the full debug log.

cc: @yaochengji @liwenchangbdbz

Failed to build onnx-frontend

Hi, thanks for this great project!

When building the onnx-frontend by following the instructions, the build fails with:

/home/shshao/byteir/frontends/onnx-frontend/third_party/onnx-mlir/third_party/mlir-hlo/mhlo/IR/hlo_ops.cc:6067:15: error: no declaration matches ‘mlir::LogicalResult mlir::mhlo::WhileOp::fold(mlir::mhlo::WhileOp::FoldAdaptor, llvm::SmallVectorImpl<mlir::OpFoldResult>&)’
 6067 | LogicalResult WhileOp::fold(FoldAdaptor /*adaptor*/,
      |               ^~~~~~~
In file included from /home/shshao/byteir/frontends/onnx-frontend/third_party/onnx-mlir/third_party/mlir-hlo/mhlo/IR/hlo_ops.h:99,
                 from /home/shshao/byteir/frontends/onnx-frontend/third_party/onnx-mlir/third_party/mlir-hlo/mhlo/IR/hlo_ops.cc:18:
/home/shshao/byteir/frontends/onnx-frontend/build/third_party/onnx-mlir/third_party/mlir-hlo/mhlo/IR/hlo_ops.h.inc:16190:25: note: candidate is: ‘mlir::LogicalResult mlir::mhlo::WhileOp::fold(llvm::ArrayRef<mlir::Attribute>, llvm::SmallVectorImpl<mlir::OpFoldResult>&)’
16190 |   ::mlir::LogicalResult fold(::llvm::ArrayRef<::mlir::Attribute> operands, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
      |                         ^~~~
In file included from /home/shshao/byteir/frontends/onnx-frontend/third_party/onnx-mlir/third_party/mlir-hlo/mhlo/IR/hlo_ops.h:99,
                 from /home/shshao/byteir/frontends/onnx-frontend/third_party/onnx-mlir/third_party/mlir-hlo/mhlo/IR/hlo_ops.cc:18:
/home/shshao/byteir/frontends/onnx-frontend/build/third_party/onnx-mlir/third_party/mlir-hlo/mhlo/IR/hlo_ops.h.inc:16158:7: note: ‘class mlir::mhlo::WhileOp’ defined here
16158 | class WhileOp : public ::mlir::Op<WhileOp, ::mlir::OpTrait::NRegions<2>::Impl, ::mlir::OpTrait::VariadicResults, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::VariadicOperands, ::mlir::OpTrait::SingleBlockImplicitTerminator<ReturnOp>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::HasRecursiveMemoryEffects, ::mlir::InferTypeOpInterface::Trait, ::mlir::OpAsmOpInterface::Trait> {
      |       ^~~~~~~
[219/555] Building CXX object third_party/onnx-mlir...Dialect/ONNX/CMakeFiles/OMONNXOps.dir/ONNXOps.cpp.o
ninja: build stopped: subcommand failed.

It seems the generated mlir-hlo/mhlo/IR/hlo_ops.h.inc is inconsistent with mlir-hlo/mhlo/IR/hlo_ops.cc. Do you know why? I have confirmed that the LLVM commit is 9acc2f37bdfce08ca0c2faec03392db10d1bb7a9 and that -DLLVM_ENABLE_RTTI=ON was set when building it:

cmake -G Ninja ../llvm \
  -DLLVM_ENABLE_PROJECTS=mlir \
  -DLLVM_TARGETS_TO_BUILD="host" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=ON \
  -DLLVM_ENABLE_RTTI=ON

How to define a custom op in torch-mlir

Can you give an example demo of converting an fx graph to torch-mlir? The fx graph includes at least one custom op, such as "dynamic_partition", "dynamic_stitch", or "dynamic_mask_stitch".
