# Description:
#   Components that implement GPU autotuning.

load(
    "@local_config_rocm//rocm:build_defs.bzl",
    "if_rocm_is_configured",
)
load("//xla:xla.default.bzl", "xla_cc_test")
load("//xla/tests:build_defs.bzl", "xla_test")
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")
load(
    "//xla/tsl/platform:build_config.bzl",
    "tf_proto_library",
)
load(
    "//xla/tsl/platform/default:cuda_build_defs.bzl",
    "if_cuda_is_configured",
)

# Package-wide defaults: every target below is visible to the ":friends"
# package group unless it sets its own `visibility`, and is covered by the
# "notice" license class.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = [":friends"],
    licenses = ["notice"],
)

# Packages allowed to depend on targets in this package. The membership is
# inherited wholesale from the top-level "//xla:friends" group.
package_group(
    name = "friends",
    includes = ["//xla:friends"],
)

# CUDA-specific part of the GEMM fusion autotuner, compiled from
# gemm_fusion_autotuner_cuda.cc. It is pulled into :gemm_fusion_autotuner via
# if_cuda_is_configured(...) and is tagged "cuda-only".
#
# gemm_fusion_autotuner.h is listed in `srcs` rather than `hdrs`, so this
# target can include it without re-exporting it; the header is published by
# :gemm_fusion_autotuner.
# NOTE(review): presumably this avoids a dependency cycle with
# :gemm_fusion_autotuner (which depends on this target) - confirm.
cc_library(
    name = "gemm_fusion_autotuner_cuda",
    srcs = [
        "gemm_fusion_autotuner.h",
        "gemm_fusion_autotuner_cuda.cc",
    ],
    tags = [
        "cuda-only",
        "gpu",
    ],
    deps = [
        ":autotuner_compile_util",
        ":autotuner_util",
        ":redzone_buffers",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass",
        "//xla/pjrt/distributed:key_value_store_interface",
        "//xla/service:algorithm_util",
        "//xla/service:executable",
        "//xla/service:shaped_buffer",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:stream_executor_util",
        "//xla/service/gpu/transforms:cudnn_fusion_compiler",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:semantic_version",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/platform:env",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)

# ROCm-specific part of the GEMM fusion autotuner, compiled from
# gemm_fusion_autotuner_rocm.cc. It is pulled into :gemm_fusion_autotuner via
# if_rocm_is_configured(...) and is tagged "rocm-only".
#
# As in the CUDA variant, gemm_fusion_autotuner.h is listed in `srcs` rather
# than `hdrs`, so it is usable here without being re-exported; the header is
# published by :gemm_fusion_autotuner.
cc_library(
    name = "gemm_fusion_autotuner_rocm",
    srcs = [
        "gemm_fusion_autotuner.h",
        "gemm_fusion_autotuner_rocm.cc",
    ],
    tags = [
        "gpu",
        "rocm-only",
    ],
    deps = [
        ":autotuner_compile_util",
        ":autotuner_util",
        ":redzone_buffers",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass",
        "//xla/pjrt/distributed:key_value_store_interface",
        "//xla/service:executable",
        "//xla/service:shaped_buffer",
        "//xla/service/gpu:matmul_utils",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:semantic_version",
        "//xla/stream_executor/rocm:rocblas_plugin",
        "//xla/tsl/platform:env",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@local_config_rocm//rocm:rocm_headers",
    ],
)

# Public GEMM fusion autotuner library. It exports gemm_fusion_autotuner.h and
# compiles the platform-independent gemm_fusion_autotuner.cc.
#
# `deps` is a concatenation of three lists: the CUDA backend (wrapped in
# if_cuda_is_configured), the ROCm backend (wrapped in if_rocm_is_configured),
# and the unconditional dependencies shared by both configurations.
cc_library(
    name = "gemm_fusion_autotuner",
    srcs = [
        "gemm_fusion_autotuner.cc",
    ],
    hdrs = ["gemm_fusion_autotuner.h"],
    tags = ["gpu"],
    deps = if_cuda_is_configured([":gemm_fusion_autotuner_cuda"]) + if_rocm_is_configured([
        ":gemm_fusion_autotuner_rocm",
    ]) + [
        ":autotuner_compile_util",
        ":autotuner_status_key",
        ":autotuner_util",
        ":dot_search_space",
        ":redzone_buffers",
        "//xla:autotune_results_proto_cc",
        "//xla:autotuning_proto_cc",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/gpu/runtime:buffer_comparator",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass",
        "//xla/hlo/pass:hlo_pass_pipeline",
        "//xla/hlo/transforms/simplifiers:float_normalization",
        "//xla/hlo/utils:hlo_query",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/pjrt/distributed:key_value_store_interface",
        "//xla/service:algorithm_util",
        "//xla/service:call_inliner",
        "//xla/service:dump",
        "//xla/service:executable",
        "//xla/service:hlo_cost_analysis",
        "//xla/service:hlo_graph_dumper",
        "//xla/service:hlo_module_config",
        "//xla/service:shaped_buffer",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_float_support",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:matmul_indexing_utils",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:split_k_gemm_rewriter",
        "//xla/service/gpu:stream_executor_util",
        "//xla/service/gpu/kernels:custom_kernel",
        "//xla/service/gpu/kernels:custom_kernel_fusion",
        "//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
        "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
        "//xla/service/gpu/transforms:dot_algorithm_rewriter",
        "//xla/service/gpu/transforms:fusion_wrapper",
        "//xla/service/gpu/transforms:gemm_rewriter",
        "//xla/service/gpu/transforms:nest_gemm_fusion",
        "//xla/service/gpu/transforms:priority_fusion",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:semantic_version",
        "//xla/stream_executor:stream",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/stream_executor/cuda:ptx_compiler_helpers",
        "//xla/stream_executor/gpu:redzone_allocator",
        "//xla/stream_executor/integrations:tf_allocator_adapter",
        "//xla/tools:hlo_decomposer_lib",
        "//xla/tsl/lib/core:bits",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:status",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:proto_utils",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@local_tsl//tsl/platform:path",
        "@local_tsl//tsl/platform:protobuf",
        "@local_tsl//tsl/profiler/lib:scoped_annotation",
        "@local_tsl//tsl/profiler/lib:traceme",
    ],
)

# Tests for :gemm_fusion_autotuner. Declared for the a100, h100 and b200
# backends with the "long" timeout, and tagged "cuda-only" and "no_mac".
# The test main comes from //xla/tests:xla_internal_test_main, which is pinned
# with "fixdeps: keep" so dependency-cleanup tooling does not drop it.
xla_test(
    name = "gemm_fusion_autotuner_test",
    timeout = "long",
    srcs = ["gemm_fusion_autotuner_test.cc"],
    backends = [
        "a100",
        "h100",
        "b200",
    ],
    tags = [
        "cuda-only",
        "no_mac",
    ],
    deps = [
        ":autotuner_util",
        ":gemm_fusion_autotuner",
        "//xla:autotune_results_proto_cc",
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass_pipeline",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:pattern_matcher_gmock",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/pjrt/distributed:key_value_store_interface",
        "//xla/service:call_inliner",
        "//xla/service:dump",
        "//xla/service:executable",
        "//xla/service:hlo_module_config",
        "//xla/service:pattern_matcher",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu/transforms:gemm_fusion",
        "//xla/service/gpu/transforms:gemm_rewriter",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/stream_executor:semantic_version",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:hlo_test_base",
        "//xla/tests:test_utils",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tools:hlo_decomposer_lib",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/platform:path",
        "@local_tsl//tsl/platform:platform_port",
    ],
)

# Library built from dot_search_space.{h,cc}; a dependency of
# :gemm_fusion_autotuner. Tagged "gpu".
cc_library(
    name = "dot_search_space",
    srcs = ["dot_search_space.cc"],
    hdrs = ["dot_search_space.h"],
    tags = ["gpu"],
    deps = [
        "//xla:shape_util",
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service/gpu:matmul_utils",
        "//xla/stream_executor:device_description",
        "//xla/tsl/lib/core:bits",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings:str_format",
        "@llvm-project//llvm:Support",
        "@local_tsl//tsl/platform:protobuf",
    ],
)

# Unit test for :dot_search_space. Declared with xla_cc_test (no `backends`
# attribute) and built on the hardware-independent HLO test base; the test
# main is provided by gtest_main.
xla_cc_test(
    name = "dot_search_space_test",
    srcs = ["dot_search_space_test.cc"],
    tags = ["gpu"],
    deps = [
        ":dot_search_space",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service/gpu:matmul_utils",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
    ],
)

# Test for :autotuner_pass. Declared for the generic "gpu" backend and tagged
# "cuda-only"; it links the cuBLAS autotuner backend and the NVPTX compiler
# implementation.
xla_test(
    name = "autotuner_pass_test",
    srcs = ["autotuner_pass_test.cc"],
    backends = ["gpu"],
    tags = ["cuda-only"],
    deps = [
        ":autotuner_pass",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/backends/gpu/autotuner:cublas",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:platform_util",
        "//xla/service/gpu:nvptx_compiler_impl",
        "//xla/stream_executor:platform_manager",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)

# Library built from gemm_algorithm_picker.{h,cc}. Shares :autotuner_util and
# :redzone_buffers with the other autotuners in this package and depends on
# the stream_executor BLAS interface. Tagged "gpu".
cc_library(
    name = "gemm_algorithm_picker",
    srcs = ["gemm_algorithm_picker.cc"],
    hdrs = ["gemm_algorithm_picker.h"],
    tags = ["gpu"],
    deps = [
        ":autotuner_util",
        ":redzone_buffers",
        "//xla:autotune_results_proto_cc",
        "//xla:autotuning_proto_cc",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_proto_cc",
        "//xla/backends/gpu/runtime:buffer_comparator",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass",
        "//xla/service:hlo_module_config",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:cublas_cudnn",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:stream_executor_util",
        "//xla/stream_executor:blas",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/stream_executor/gpu:redzone_allocator",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:logging",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:proto_utils",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@local_tsl//tsl/profiler/lib:scoped_annotation",
    ],
)

# Small library built from autotuner_status_key.{h,cc}. Its only dependency is
# absl/strings, and it is declared compatible with the portable environment
# via get_compatible_with_portable().
cc_library(
    name = "autotuner_status_key",
    srcs = ["autotuner_status_key.cc"],
    hdrs = ["autotuner_status_key.h"],
    compatible_with = get_compatible_with_portable(),
    deps = [
        "@com_google_absl//absl/strings",
    ],
)

# Shared autotuning utility library built from autotuner_util.{h,cc}. Most
# autotuner libraries and tests in this package depend on it. Like
# :autotuner_status_key, it is declared compatible with the portable
# environment via get_compatible_with_portable().
cc_library(
    name = "autotuner_util",
    srcs = ["autotuner_util.cc"],
    hdrs = ["autotuner_util.h"],
    compatible_with = get_compatible_with_portable(),
    deps = [
        ":autotuner_status_key",
        "//xla:autotune_results_proto_cc",
        "//xla:autotuning_proto_cc",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:dump",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor:stream_executor_memory_allocator",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:logging",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/time",
        "@llvm-project//llvm:Support",
        "@local_tsl//tsl/platform:base64",
        "@local_tsl//tsl/platform:path",
        "@local_tsl//tsl/platform:protobuf",
    ],
)

# Compilation helpers for autotuning, built from autotuner_compile_util.{h,cc}.
# This is a separate target from :autotuner_util because the runtime
# executable cannot depend on the compilation pipeline; only this target
# depends on //xla/service:compiler.
cc_library(
    name = "autotuner_compile_util",
    srcs = ["autotuner_compile_util.cc"],
    hdrs = ["autotuner_compile_util.h"],
    tags = ["gpu"],
    deps = [
        ":autotuner_util",
        "//xla:executable_run_options",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:compiler",
        "//xla/service:executable",
        "//xla/service:maybe_owning_device_memory",
        "//xla/service:shaped_buffer",
        "//xla/service/gpu:gpu_executable_run_options",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:stream",
        "//xla/stream_executor/gpu:redzone_allocator",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@local_tsl//tsl/profiler/lib:traceme",
    ],
)

# Library built from redzone_buffers.{h,cc}, layered on the stream_executor
# redzone allocator. The GEMM, conv and custom-kernel autotuners in this
# package all depend on it.
cc_library(
    name = "redzone_buffers",
    srcs = ["redzone_buffers.cc"],
    hdrs = ["redzone_buffers.h"],
    deps = [
        "//xla:executable_run_options",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:executable",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:stream",
        "//xla/stream_executor/gpu:redzone_allocator",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
        "@local_tsl//tsl/profiler/lib:traceme",
    ],
)

# Test for :redzone_buffers, declared for the generic "gpu" backend. The test
# main is provided by gtest_main.
xla_test(
    name = "redzone_buffers_test",
    srcs = ["redzone_buffers_test.cc"],
    backends = ["gpu"],
    deps = [
        ":redzone_buffers",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:platform_util",
        "//xla/stream_executor:platform",
        "//xla/stream_executor:stream",
        "//xla/stream_executor:stream_executor_memory_allocator",
        "//xla/tests:hlo_test_base",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
    ],
)

# Test for :gemm_algorithm_picker, declared for both an NVIDIA backend
# ("v100") and an AMD backend ("amdgpu_any"). The test main is provided by
# //xla/tsl/platform:test_main.
xla_test(
    name = "gemm_algorithm_picker_test",
    srcs = ["gemm_algorithm_picker_test.cc"],
    backends = [
        "v100",
        "amdgpu_any",
    ],
    deps = [
        ":autotuner_util",
        ":gemm_algorithm_picker",
        "//xla:autotune_results_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:pattern_matcher_gmock",
        "//xla/service:pattern_matcher",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu/transforms:gemm_rewriter",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:platform",
        "//xla/stream_executor:semantic_version",
        "//xla/tests:hlo_test_base",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "//xla/tsl/platform:test_main",
        "//xla/tsl/protobuf:dnn_proto_cc",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/strings:string_view",
    ],
)

# Library built from conv_algorithm_picker.{h,cc}. It depends on both the CUDA
# and the ROCm platform-id targets, and on :gpu_autotuning_proto_cc
# (presumably the C++ target generated by the :gpu_autotuning_proto
# tf_proto_library in this package - confirm). Tagged "gpu".
cc_library(
    name = "conv_algorithm_picker",
    srcs = ["conv_algorithm_picker.cc"],
    hdrs = ["conv_algorithm_picker.h"],
    tags = ["gpu"],
    deps = [
        ":autotuner_compile_util",
        ":autotuner_util",
        ":gpu_autotuning_proto_cc",
        ":redzone_buffers",
        "//xla:autotune_results_proto_cc",
        "//xla:autotuning_proto_cc",
        "//xla:debug_options_flags",
        "//xla:literal_util",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/gpu/runtime:buffer_comparator",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass",
        "//xla/service:executable",
        "//xla/service:hlo_module_config",
        "//xla/service:slow_operation_alarm",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:cublas_cudnn",
        "//xla/service/gpu:gpu_asm_opts_util",
        "//xla/service/gpu:gpu_conv_runner",
        "//xla/service/gpu:hlo_algorithm_denylist",
        "//xla/service/gpu:stream_executor_util",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:dnn",
        "//xla/stream_executor:lazy_op_runner",
        "//xla/stream_executor:numeric_options",
        "//xla/stream_executor:platform",
        "//xla/stream_executor:scratch_allocator",
        "//xla/stream_executor:stream",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/stream_executor/cuda:cuda_platform_id",
        "//xla/stream_executor/gpu:redzone_allocator",
        "//xla/stream_executor/rocm:rocm_platform_id",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:logging",
        "//xla/tsl/platform:status",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/protobuf:dnn_proto_cc",
        "//xla/tsl/util:env_var",
        "//xla/tsl/util/proto:proto_utils",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@local_tsl//tsl/platform:numbers",
    ],
)

# Test for :conv_algorithm_picker. Tagged "noasan" and "nomsan" in addition to
# "cuda-only"; the test main is provided by //xla/tsl/platform:test_main.
#
# NOTE(review): `backends` lists "amdgpu_any" while `tags` contains
# "cuda-only". If "cuda-only" excludes the target from ROCm builds, the AMD
# backend entry never runs - confirm which of the two is intended.
xla_test(
    name = "conv_algorithm_picker_test",
    srcs = ["conv_algorithm_picker_test.cc"],
    backends = [
        "v100",
        "amdgpu_any",
    ],
    tags = [
        "cuda-only",
        "noasan",
        "nomsan",
    ],
    deps = [
        ":autotuner_util",
        ":conv_algorithm_picker",
        "//xla:autotune_results_proto_cc",
        "//xla:debug_options_flags",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:pattern_matcher_gmock",
        "//xla/hlo/transforms/simplifiers:tuple_simplifier",
        "//xla/service:pattern_matcher",
        "//xla/service:platform_util",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:stream_executor_util",
        "//xla/service/gpu/transforms:conv_rewriter",
        "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:dnn",
        "//xla/stream_executor:platform",
        "//xla/tests:hlo_test_base",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "//xla/tsl/platform:test_main",
        "@com_google_absl//absl/strings:string_view",
    ],
)

# Library built from custom_kernel_fusion_autotuner.{h,cc}. Builds on the
# shared autotuning helpers (:autotuner_compile_util, :autotuner_util,
# :redzone_buffers) and the custom-kernel targets under
# //xla/service/gpu/kernels. Tagged "gpu".
cc_library(
    name = "custom_kernel_fusion_autotuner",
    srcs = ["custom_kernel_fusion_autotuner.cc"],
    hdrs = ["custom_kernel_fusion_autotuner.h"],
    tags = ["gpu"],
    deps = [
        ":autotuner_compile_util",
        ":autotuner_util",
        ":redzone_buffers",
        "//xla:autotuning_proto_cc",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass",
        "//xla/service:executable",
        "//xla/service:shaped_buffer",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu/kernels:custom_kernel",
        "//xla/service/gpu/kernels:custom_kernel_fusion",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:stream",
        "//xla/stream_executor:stream_executor_memory_allocator",
        "//xla/tools:hlo_decomposer_lib",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@local_tsl//tsl/platform:path",
    ],
)

# Test for :custom_kernel_fusion_autotuner, declared for the generic "gpu"
# backend and tagged "cuda-only". The test main comes from
# //xla/tests:xla_internal_test_main, pinned with "fixdeps: keep".
xla_test(
    name = "custom_kernel_fusion_autotuner_test",
    srcs = ["custom_kernel_fusion_autotuner_test.cc"],
    backends = ["gpu"],
    tags = ["cuda-only"],
    deps = [
        ":autotuner_util",
        ":custom_kernel_fusion_autotuner",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass_pipeline",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/platform:test",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/platform:path",
    ],
)

# Proto library for gpu_autotuning.proto. `protodeps` lists the proto
# libraries it depends on. :conv_algorithm_picker depends on
# ":gpu_autotuning_proto_cc" (presumably the C++ target this macro
# generates - confirm against tf_proto_library).
tf_proto_library(
    name = "gpu_autotuning_proto",
    srcs = ["gpu_autotuning.proto"],
    protodeps = [
        "//xla/service/gpu:backend_configs",
        "//xla:xla_data_proto",
        "//xla/service:hlo_proto",
        "//xla:autotuning_proto",
    ],
)

# Unit test for :autotuner_util and :autotuner_status_key. Declared with
# xla_cc_test (no `backends` attribute) and linked against the host platform.
# `data` makes three GPU spec text protos (two A100 variants and MI200)
# available to the test at run time. The test main comes from
# //xla/tests:xla_internal_test_main.
xla_cc_test(
    name = "autotuner_util_test",
    srcs = ["autotuner_util_test.cc"],
    data = [
        "//xla/tools/hlo_opt:gpu_specs/a100_sxm_40.txtpb",
        "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb",
        "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb",
    ],
    tags = [
        "gpu",
    ],
    deps = [
        ":autotuner_status_key",
        ":autotuner_util",
        "//xla:autotune_results_proto_cc",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_query",
        "//xla/service:dump",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/stream_executor:platform",
        "//xla/stream_executor:platform_manager",
        "//xla/stream_executor/host:host_platform",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:logging",
        "//xla/tsl/platform:status",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/hash:hash_testing",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_googletest//:gtest",
        "@local_tsl//tsl/platform:path",
        "@local_tsl//tsl/platform:protobuf",
    ],
)

# Library built from autotuner_pass.{h,cc}. Depends on the HLO pass interface,
# the generic autotuner under //xla/backends/autotuner and the GPU profiler.
# Tagged "gpu".
cc_library(
    name = "autotuner_pass",
    srcs = ["autotuner_pass.cc"],
    hdrs = ["autotuner_pass.h"],
    tags = ["gpu"],
    deps = [
        "//xla/backends/autotuner",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/backends/autotuner:profiler",
        "//xla/backends/gpu/autotuner:gpu_profiler",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass",
        "//xla/stream_executor:platform",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],
)
