load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("//xla:xla.default.bzl", "xla_cc_test")
load(
    "//xla/stream_executor:build_defs.bzl",
    "if_cuda_or_rocm_is_configured",
    "if_gpu_is_configured",
)
load("//xla/tests:build_defs.bzl", "xla_test")
load("//xla/tsl:tsl.bzl", "if_google")
load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")

# Package-wide defaults: targets are visible to the ":friends" package group
# declared below unless a rule overrides `visibility` explicitly.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = [":friends"],
    licenses = ["notice"],
)

# Packages allowed to depend on this package by default; inherits the
# XLA-wide "//xla:friends" group.
package_group(
    name = "friends",
    includes = [
        "//xla:friends",
    ],
)

# Triton fusion support for the GPU codegen framework (fusion.cc/.h).
cc_library(
    name = "fusion",
    srcs = ["fusion.cc"],
    hdrs = ["fusion.h"],
    # Deliberately narrower than the package default (":friends").
    visibility = [
        "//xla/backends/gpu/codegen:__subpackages__",
        "//xla/service/gpu:__subpackages__",
    ],
    deps = [
        ":fusion_emitter",
        ":fusion_emitter_legacy_matmul",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla/backends/gpu/codegen:fusion_emitter",
        "//xla/backends/gpu/runtime:kernel_thunk",
        "//xla/backends/gpu/runtime:thunk",
        "//xla/codegen/emitters:kernel_arguments",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_constants",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:kernel_reuse_cache",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:triton_fusion_analysis",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/stream_executor:device_description",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)

# Unit tests for :fusion. Hardware-independent (uses the HLO test base and
# canned device infos), but tagged "gpu" so it runs in GPU test pools.
xla_cc_test(
    name = "fusion_test",
    srcs = ["fusion_test.cc"],
    tags = ["gpu"],
    deps = [
        ":fusion",
        "//xla/backends/gpu/codegen:fusion_emitter",
        "//xla/backends/gpu/codegen:fusions",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/stream_executor:device_description",
        "//xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/status",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:status_matchers",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Helper utilities shared by the Triton emitters in this package
# (:dot_algorithms, :fusion_emitter, :fusion_emitter_legacy_matmul all
# depend on it).
cc_library(
    name = "emitter_helpers",
    srcs = ["emitter_helpers.cc"],
    hdrs = [
        "emitter_helpers.h",
    ],
    deps = [
        "//xla:literal",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/hlo/ir:hlo",
        "//xla/mlir_hlo",
        "//xla/mlir_hlo:map_mhlo_to_scalar_op",
        "//xla/mlir_hlo:transformation_helpers",
        "//xla/service/gpu:target_util",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/stream_executor:device_description",
        "//xla/tsl/platform:status",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:TargetParser",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:statusor",
        "@triton//:TritonDialects",
    ],
)

# Triton compilation pipeline. Exactly one srcs variant is selected:
# the stub when no GPU backend is configured, otherwise the CUDA or ROCm
# implementation.
cc_library(
    name = "compilation_pipeline",
    srcs = if_gpu_is_configured(
        [],
        ["compilation_pipeline_stub.cc"],
    ) + if_cuda_is_configured([
        "compilation_pipeline_cuda.cc",
    ]) + if_rocm_is_configured([
        "compilation_pipeline_rocm.cc",
    ]),
    hdrs = ["compilation_pipeline.h"],
    # Each deps list is kept in buildifier sort order: ":" targets first,
    # then "//" targets, then "@" external repositories.
    deps = [
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status",
        "@llvm-project//mlir:Pass",
    ] + if_gpu_is_configured([
        "//xla/backends/gpu/codegen/triton/transforms:passes",
        "//xla/service:hlo_module_config",
        "//xla/service/gpu:matmul_utils",
        "//xla/stream_executor:device_description",
        "@com_google_absl//absl/strings:str_format",
        "@llvm-project//mlir:ArithToLLVM",
        "@llvm-project//mlir:ControlFlowToLLVM",
        "@llvm-project//mlir:IndexToLLVM",
        "@llvm-project//mlir:SCFToControlFlow",
        "@llvm-project//mlir:Transforms",
        "@triton//:TritonDialects",
        "@triton//:TritonGPUToLLVM",
        "@triton//:TritonGPUTransforms",
        "@triton//:TritonLLVMIR",
        "@triton//:TritonNvidiaGPUTransforms",
        "@triton//:TritonToTritonGPU",
        "@triton//:TritonToTritonGPUPasses",
        "@triton//:TritonTransforms",
        "@triton//:WarpSpecialization",
    ]) + if_cuda_is_configured([
        "//xla/service/gpu/llvm_gpu_backend:nvptx_backend",
        "//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path",
        "@triton//third_party/nvidia:NVGPUToLLVM",
        "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
    ]) + if_rocm_is_configured([
        "//xla/service/gpu/llvm_gpu_backend:amdgpu_backend",
        "//xla/tsl/platform:rocm_rocdl_path",
        "@triton//third_party/amd:TritonAMDGPUToLLVM",
        "@triton//third_party/amd:TritonAMDGPUTransforms",
    ]),
)

# Core Triton fusion emitter. The real implementation is only compiled for
# CUDA/ROCm builds; other configurations get the stub.
cc_library(
    name = "fusion_emitter",
    # Using if_cuda_or_rocm_is_configured guard to prevent sycl target build / link errors.
    srcs = if_cuda_or_rocm_is_configured(
        ["fusion_emitter.cc"],
        ["fusion_emitter_stub.cc"],
    ),
    hdrs = ["fusion_emitter.h"],
    # Each deps list is kept in buildifier sort order: ":" targets first,
    # then "//" targets, then "@" external repositories.
    deps = [
        ":compilation_pipeline",
        ":dot_algorithms",
        ":emitter_helpers",
        ":fusion_emitter_legacy_matmul",
        ":support",
        ":tma_utils",
        "//xla:autotuning_proto_cc",
        "//xla:permutation_util",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
        "//xla/backends/gpu/codegen/emitters/transforms:passes",
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/backends/gpu/codegen/triton/transforms:passes",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/codegen/emitters:elemental_hlo_to_mlir",
        "//xla/codegen/emitters/ir:xla",
        "//xla/codegen/emitters/transforms:passes",
        "//xla/hlo/analysis:indexing_analysis",
        "//xla/hlo/builder:xla_builder",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/mlir_hlo",
        "//xla/service:dump",
        "//xla/service:hlo_module_config",
        "//xla/service:instruction_fusion",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:triton_fusion_analysis",
        "//xla/service/gpu/model:symbolic_tile_analysis",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/service/gpu/model:triton_emitter_constraints",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:launch_dim",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tools:hlo_decomposer_lib",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Linker",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:AffineDialect",
        "@llvm-project//mlir:AffineToStandard",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:ArithToLLVM",
        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
        "@llvm-project//mlir:ControlFlowToLLVM",
        "@llvm-project//mlir:ExecutionEngineUtils",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:FuncExtensions",
        "@llvm-project//mlir:FunctionInterfaces",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:IndexToLLVM",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:LLVMIRTransforms",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:NVVMDialect",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ROCDLToLLVMIRTranslation",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:SCFToControlFlow",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:ToLLVMIRTranslation",
        "@llvm-project//mlir:Transforms",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:path",
        "@local_tsl//tsl/platform:statusor",
        "@triton//:TritonDialects",
        "@triton//:TritonTransforms",
    ] + if_gpu_is_configured([
        "@triton//:TritonGPUToLLVM",
        "@triton//:TritonGPUTransforms",
        "@triton//:TritonLLVMIR",
        "@triton//:TritonNvidiaGPUTransforms",
        "@triton//:TritonToTritonGPU",
    ]) + if_cuda_is_configured([
        "//xla/service/gpu/llvm_gpu_backend:nvptx_backend",
        "//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path",
        "@triton//third_party/nvidia:NVGPUToLLVM",
        "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
    ]) + if_rocm_is_configured([
        "//xla/service/gpu/llvm_gpu_backend:amdgpu_backend",
        "//xla/tsl/platform:rocm_rocdl_path",
        "@triton//third_party/amd:TritonAMDGPUToLLVM",
        "@triton//third_party/amd:TritonAMDGPUTransforms",
    ]),
)

# Legacy matmul emitter path. The real implementation is only compiled when
# a GPU backend is configured; other builds get the stub.
cc_library(
    name = "fusion_emitter_legacy_matmul",
    srcs = if_gpu_is_configured(
        ["fusion_emitter_legacy_matmul.cc"],
        ["fusion_emitter_legacy_matmul_stub.cc"],
    ),
    hdrs = ["fusion_emitter_legacy_matmul.h"],
    deps = [
        ":dot_algorithms",
        ":emitter_helpers",
        "//xla:autotuning_proto_cc",
        "//xla:comparison_util",
        "//xla:literal",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_query",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/mlir_hlo",
        "//xla/mlir_hlo:map_mhlo_to_scalar_op",
        "//xla/mlir_hlo:transformation_helpers",
        "//xla/service:algorithm_util",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:matmul_indexing_utils",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:triton_fusion_analysis",
        "//xla/service/gpu:triton_tiling_propagation",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:launch_dim",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:FunctionInterfaces",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Support",
        "@triton//:TritonDialects",
    ],
)

# Emission of dot/matmul algorithm variants (e.g. TF32 handling — note the
# tensor_float_32 dep) used by the Triton emitters.
cc_library(
    name = "dot_algorithms",
    srcs = ["dot_algorithms.cc"],
    hdrs = ["dot_algorithms.h"],
    deps = [
        ":emitter_helpers",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service:algorithm_util",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:tensor_float_32_hdr_lib",
        "@triton//:TritonDialects",
    ],
)

# Bundles all the stub implementations unconditionally so the stub code
# paths can be built and tested regardless of GPU configuration.
cc_library(
    name = "fusion_emitter_stub_for_testing",
    srcs = [
        "compilation_pipeline_stub.cc",
        "fusion_emitter_legacy_matmul_stub.cc",
        "fusion_emitter_stub.cc",
    ],
    hdrs = [
        "compilation_pipeline.h",
        "fusion_emitter.h",
        "fusion_emitter_legacy_matmul.h",
    ],
    deps = [
        "//xla:autotuning_proto_cc",
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service:hlo_module_config",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:triton_fusion_analysis",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:launch_dim",
        "//xla/stream_executor/gpu:tma_metadata",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:FunctionInterfaces",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@triton//:TritonDialects",
    ],
)

# Tests the stub implementations via :fusion_emitter_stub_for_testing, so
# it runs even in builds where the real emitter is selected.
xla_cc_test(
    name = "fusion_emitter_stub_test",
    srcs = ["fusion_emitter_stub_test.cc"],
    deps = [
        ":fusion_emitter_stub_for_testing",
        "//xla:literal",
        "//xla:literal_util",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service:hlo_module_config",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
    ],
)

# Emitter tests that do not need a physical device (uses canned device
# infos and FileCheck on the emitted IR).
xla_cc_test(
    name = "fusion_emitter_deviceless_test",
    srcs = ["fusion_emitter_deviceless_test.cc"],
    tags = ["no_oss"],  # Doesn't pass in OSS when building with the `fusion_emitter_stub`.
    deps = [
        ":fusion_emitter",
        "//xla:xla_proto_cc",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:xla_internal_test_main",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
    ],
)

# Device test for the legacy emitter path; srcs are empty (and the test is
# a no-op) unless a GPU backend is configured.
xla_test(
    name = "fusion_emitter_device_legacy_test",
    size = "large",
    srcs = if_gpu_is_configured(["fusion_emitter_device_legacy_test.cc"]),
    # TODO(b/372714955): Fix the memory leak!
    # Heap checking is disabled on a100/h100 (Google-internal only) until
    # the leak above is fixed.
    backend_args = if_google(
        {
            "h100": ["--heap_check="],
            "a100": ["--heap_check="],
        },
        {},
    ),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    shard_count = 20,
    tags = [
        "large",
        "no_mac",
    ],
    deps = [
        ":fusion_emitter",
        ":test_utils",
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:pattern_matcher_gmock",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service:pattern_matcher",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@local_tsl//tsl/platform:path",
    ],
)

# Device test covering the port of the legacy emitter tests (note the
# nest_gemm_fusion dep); srcs are empty unless a GPU backend is configured.
xla_test(
    name = "fusion_emitter_device_legacy_port_test",
    srcs = if_gpu_is_configured(["fusion_emitter_device_legacy_port_test.cc"]),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    tags = [
        "no_mac",
    ],
    deps = [
        ":fusion_emitter",
        ":test_utils",
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:pattern_matcher_gmock",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/hlo/utils:hlo_query",
        "//xla/service:pattern_matcher",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/service/gpu/transforms:nest_gemm_fusion",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@local_tsl//tsl/platform:path",
    ],
)

# Device test for int4 support in the (non-legacy) emitter; srcs are empty
# unless a GPU backend is configured.
xla_test(
    name = "fusion_emitter_int4_device_test",
    size = "large",
    srcs = if_gpu_is_configured(["fusion_emitter_int4_device_test.cc"]),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    shard_count = 10,
    tags = [
        "large",
        "no_mac",
    ],
    deps = [
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/service/gpu/transforms:nest_gemm_fusion",
        "//xla/stream_executor:device_description",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/platform:path",
    ],
)

# Device test for int4 support in the legacy emitter; srcs are empty
# unless a GPU backend is configured.
xla_test(
    name = "fusion_emitter_legacy_int4_device_test",
    size = "large",
    srcs = if_gpu_is_configured(["fusion_emitter_legacy_int4_device_test.cc"]),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    shard_count = 10,
    tags = [
        "large",
        "no_mac",
    ],
    deps = [
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/platform:path",
    ],
)

# Device test for :dot_algorithms; srcs are empty unless a GPU backend is
# configured. Uses :kernel_name_tracer to inspect which kernels actually ran.
xla_test(
    name = "dot_algorithms_test",
    srcs = if_gpu_is_configured(["dot_algorithms_test.cc"]),
    # Heap checking disabled on a100/h100 (Google-internal only) —
    # presumably for the same leak as fusion_emitter_device_legacy_test;
    # TODO confirm.
    backend_args = if_google(
        {
            "h100": ["--heap_check="],
            "a100": ["--heap_check="],
        },
        {},
    ),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    # Both variables trigger single precision emulation (F32_F32_F32) with
    # the BF16x9 cublas algorithm, introduced in cublas 12.9.
    env = {
        "CUBLAS_EMULATE_SINGLE_PRECISION": "1",
        "CUBLAS_EMULATION_STRATEGY": "performant",
    },
    shard_count = 30,
    tags = [
        "no_mac",
    ],
    deps = [
        ":kernel_name_tracer",
        ":test_utils",
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:literal",
        "//xla:literal_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service:dump",
        "//xla/service:hlo_module_config",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:test_utils",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest",
        "@llvm-project//llvm:Support",
        "@local_tsl//tsl/platform:path",
    ],
)

# Main device test for :fusion_emitter; srcs are empty unless a GPU
# backend is configured.
xla_test(
    name = "fusion_emitter_device_test",
    srcs = if_gpu_is_configured(["fusion_emitter_device_test.cc"]),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    shard_count = 5,
    tags = [
        "no_mac",
    ],
    deps = [
        ":fusion_emitter",
        ":support",
        ":test_utils",
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:literal",
        "//xla:literal_util",
        "//xla:shape_util",
        "//xla:types",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service:algorithm_util",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:test_utils",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
        "@eigen_archive//:eigen3",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@local_tsl//tsl/platform:path",
    ],
)

# CUPTI-based kernel-name tracer implementation; only built under CUDA
# (srcs are empty otherwise). Consumed via the :kernel_name_tracer facade.
cc_library(
    name = "kernel_name_tracer_cuda",
    testonly = True,
    srcs = if_cuda(["kernel_name_tracer_cuda.cc"]),
    hdrs = ["kernel_name_tracer.h"],
    tags = ["manual"],  # Need to exclude this from wildcard builds
    deps = [
        "//xla/backends/profiler/gpu:cupti_collector",
        "//xla/backends/profiler/gpu:cupti_tracer",
        "//xla/tsl/profiler/utils:time_utils",
        "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc",
    ],
)

# No-op kernel-name tracer used when CUDA is not configured; shares the
# same header as the CUDA implementation.
cc_library(
    name = "kernel_name_tracer_noop",
    testonly = True,
    srcs = ["kernel_name_tracer_noop.cc"],
    hdrs = ["kernel_name_tracer.h"],
    tags = ["manual"],  # Need to exclude this from wildcard builds
)

# Facade target: selects the CUPTI-backed tracer under CUDA and the no-op
# implementation otherwise. Tests should depend on this, not the impls.
cc_library(
    name = "kernel_name_tracer",
    testonly = True,
    hdrs = ["kernel_name_tracer.h"],
    deps = if_cuda(
        [":kernel_name_tracer_cuda"],
        [":kernel_name_tracer_noop"],
    ),
)

# Shared utilities for the emitter tests in this package (testonly).
cc_library(
    name = "test_utils",
    testonly = True,
    srcs = ["test_utils.cc"],
    hdrs = ["test_utils.h"],
    deps = [
        ":fusion_emitter",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass_pipeline",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/hlo/transforms/simplifiers:float_normalization",
        "//xla/hlo/utils:hlo_query",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:gpu_float_support",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/service/gpu/transforms:nest_gemm_fusion",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:hlo_test_base",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_for_library",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:protobuf",
    ],
)

# Memory-hungry device test (needs 16g internally); srcs are empty unless
# a GPU backend is configured.
xla_test(
    name = "fusion_emitter_large_test",
    size = "large",
    srcs = if_gpu_is_configured(["fusion_emitter_large_test.cc"]),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    tags = [
        "large",
        "no_mac",
        "no_oss",  # requires-mem:16g tag doesn't work in open source
        "nozapfhahn",  # Times out under coverage
    ] + if_google([
        "requires-mem:16g",
    ]),
    deps = [
        "//xla:error_spec",
        "//xla:xla_proto_cc",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest",
    ],
)

# Parameterized device test for the emitter; srcs are empty unless a GPU
# backend is configured.
xla_test(
    name = "fusion_emitter_parametrized_test",
    srcs = if_gpu_is_configured(["fusion_emitter_parametrized_test.cc"]),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    tags = ["no_mac"],
    deps = [
        ":support",
        ":test_utils",
        "//xla:comparison_util",
        "//xla:error_spec",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
    ],
)

# Legacy-emitter counterpart of fusion_emitter_parametrized_test; srcs are
# empty unless a GPU backend is configured.
xla_test(
    name = "fusion_emitter_parametrized_legacy_test",
    srcs = if_gpu_is_configured(["fusion_emitter_parametrized_legacy_test.cc"]),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    shard_count = 10,
    tags = ["no_mac"],
    deps = [
        ":support",
        ":test_utils",
        "//xla:comparison_util",
        "//xla:error_spec",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
    ],
)

# Support helpers shared by the Triton fusion emitter targets in this package;
# bundles both the current (support.*) and legacy (support_legacy.*) sources
# into a single library.
cc_library(
    name = "support",
    srcs = [
        "support.cc",
        "support_legacy.cc",
    ],
    hdrs = [
        "support.h",
        "support_legacy.h",
    ],
    deps = [
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:algorithm_util",
        "//xla/service:instruction_fusion",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:matmul_indexing_utils",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@local_tsl//tsl/platform:tensor_float_32_utils",
    ],
)

# Unit tests for :support. Sharded heavily because the test suite is large.
xla_cc_test(
    name = "support_test",
    srcs = ["support_test.cc"],
    shard_count = 25,
    # TODO(b/353912594): this test does not need to run on GPU, but it is broken on CPU in OSS.
    # Force it to run on GPU temporarily in order to get important OSS coverage.
    tags = ["gpu"],
    deps = [
        ":fusion_emitter",
        ":support",
        ":test_utils",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
        "@local_tsl//tsl/platform:protobuf",
    ],
)

# Tests for the legacy half of :support (support_legacy.*). GPU-only: `srcs`
# is empty unless a GPU toolchain is configured.
xla_test(
    name = "support_legacy_test",
    srcs = if_gpu_is_configured(["support_legacy_test.cc"]),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    tags = ["no_mac"],
    deps = [
        ":fusion_emitter",
        ":support",
        ":test_utils",
        "//xla:error_spec",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:triton_fusion_analysis",
        "//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
        "//xla/stream_executor:device_description",
        "//xla/tsl/lib/core:status_test_util",
        # Use the migrated //xla/tsl labels rather than @local_tsl, for
        # consistency with every other target in this file (see :support_test).
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)

# Helpers for working with TMA metadata (see
# //xla/stream_executor/gpu:tma_metadata) in the Triton emitter.
cc_library(
    name = "tma_utils",
    srcs = ["tma_utils.cc"],
    hdrs = ["tma_utils.h"],
    deps = [
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@llvm-project//llvm:Support",
    ],
)

# Unit tests for :tma_utils. Plain cc_test (no GPU backend matrix) since the
# utilities are testable on the host.
xla_cc_test(
    name = "tma_utils_test",
    srcs = ["tma_utils_test.cc"],
    deps = [
        ":tma_utils",
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)
