load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//xla:xla.default.bzl", "xla_cc_binary")
load("//xla/tests:build_defs.bzl", "xla_test")

# Package-wide defaults for every target in this BUILD file.
# NOTE: the `# copybara:uncomment ...` line inside the call is a tooling
# directive rather than an ordinary comment; keep its text verbatim.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    # Unless a target sets its own `visibility`, only these two package trees
    # may depend on targets declared here.
    default_visibility = [
        "//xla/backends/autotuner:__subpackages__",
        "//xla/service/gpu:__subpackages__",
    ],
    licenses = ["notice"],
)

# Header-only target (no `srcs`) exporting gpu_codegen_backend.h.
# The backend libraries in this file (:block_level_emitter, :cublas,
# :cublaslt, :cudnn, :custom_kernel, :fission, :triton) all list it in `deps`.
cc_library(
    name = "gpu_codegen_backend",
    hdrs = ["gpu_codegen_backend.h"],
    deps = [
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/service:compiler",
        "//xla/service:executable",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tools:hlo_decomposer_lib",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],
)

# Library built from block_level_emitter.{cc,h}; depends on the shared
# :gpu_codegen_backend header target.
cc_library(
    name = "block_level_emitter",
    srcs = ["block_level_emitter.cc"],
    hdrs = ["block_level_emitter.h"],
    # NOTE(review): "gpu" is a project-defined tag; presumably used to filter
    # this target out of non-GPU builds -- confirm against the CI tag filters.
    tags = ["gpu"],
    deps = [
        ":gpu_codegen_backend",
        "//xla:autotuning_proto_cc",
        "//xla:shape_util",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/service:compiler",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

# Test for :block_level_emitter, declared through the xla_test macro loaded
# from //xla/tests:build_defs.bzl.
xla_test(
    name = "block_level_emitter_test",
    srcs = ["block_level_emitter_test.cc"],
    # Backend names handed to xla_test; presumably one test variant is
    # generated per entry -- confirm in //xla/tests:build_defs.bzl.
    backends = [
        "a100",
        "h100",
        "b200",
    ],
    # NOTE(review): project-defined tags; presumably restrict the test to CUDA
    # builds and exclude macOS -- confirm against the CI tag filters.
    tags = [
        "cuda-only",
        "no_mac",
    ],
    deps = [
        ":block_level_emitter",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:platform_util",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:nvptx_compiler_impl",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:proto_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_googletest//:gtest_main",
    ],
)

# Library built from cublas.{cc,h}. Used directly by :fission and :factory in
# this file.
# NOTE(review): unlike :block_level_emitter, :cudnn, :fission, :triton and
# :factory, this target carries no "gpu" tag even though gpu-tagged targets
# depend on it -- confirm that building it in non-GPU configurations is
# intentional.
cc_library(
    name = "cublas",
    srcs = ["cublas.cc"],
    hdrs = ["cublas.h"],
    deps = [
        ":gpu_codegen_backend",
        "//xla:autotuning_proto_cc",
        "//xla:shape_util",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_query",
        "//xla/service:compiler",
        "//xla/service:hlo_cost_analysis",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:cublas_cudnn",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu/autotuning:redzone_buffers",
        "//xla/service/gpu/transforms:dot_algorithm_rewriter",
        "//xla/service/gpu/transforms:gemm_rewriter",
        "//xla/service/gpu/transforms:priority_fusion",
        "//xla/stream_executor:blas",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:semantic_version",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor:stream_executor_memory_allocator",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/stream_executor/gpu:redzone_allocator",
        "//xla/tools:hlo_decomposer_lib",
        "//xla/tsl/lib/gtl:iterator_range",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ] + if_cuda([
        # Appended only through if_cuda (loaded from @local_config_cuda);
        # presumably this resolves to an empty list when CUDA is not
        # configured -- confirm in that macro's definition.
        "//xla/stream_executor/cuda:repeat_buffer_kernel_cuda",
    ]),
)

# Test for :cublas, declared through the xla_test macro.
xla_test(
    name = "cublas_test",
    srcs = ["cublas_test.cc"],
    # Same device-specific backend list as the other backend tests in this
    # file.
    backends = [
        "a100",
        "h100",
        "b200",
    ],
    # NOTE(review): project-defined tags; presumably restrict the test to CUDA
    # builds and exclude macOS -- confirm against the CI tag filters.
    tags = [
        "cuda-only",
        "no_mac",
    ],
    deps = [
        ":cublas",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:compiler",
        "//xla/service:executable",
        "//xla/service:platform_util",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:nvptx_compiler_impl",
        "//xla/stream_executor:blas",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:proto_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_googletest//:gtest_main",
    ],
)

# Library built from cublaslt.{cc,h}. Used directly by :fission in this file.
# NOTE(review): no "gpu" tag here although the gpu-tagged :fission depends on
# this target -- confirm the omission is intentional.
cc_library(
    name = "cublaslt",
    srcs = ["cublaslt.cc"],
    hdrs = ["cublaslt.h"],
    deps = [
        ":gpu_codegen_backend",
        "//xla:autotuning_proto_cc",
        "//xla:util",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/service:compiler",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:cublas_cudnn",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu/autotuning:redzone_buffers",
        "//xla/stream_executor:blas",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor:stream_executor_memory_allocator",
        "//xla/stream_executor/gpu:gpu_blas_lt",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],
)

# Test for :cublaslt, declared through the xla_test macro.
xla_test(
    name = "cublaslt_test",
    srcs = ["cublaslt_test.cc"],
    # Same device-specific backend list as the other backend tests in this
    # file.
    backends = [
        "a100",
        "h100",
        "b200",
    ],
    # NOTE(review): project-defined tags; presumably restrict the test to CUDA
    # builds and exclude macOS -- confirm against the CI tag filters.
    tags = [
        "cuda-only",
        "no_mac",
    ],
    deps = [
        ":cublaslt",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:compiler",
        "//xla/service:platform_util",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:nvptx_compiler_impl",
        "//xla/stream_executor:blas",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_googletest//:gtest_main",
    ],
)

# Library built from cudnn.{cc,h}. It is the only cc_library in this file
# tagged "cuda-only" in addition to "gpu".
cc_library(
    name = "cudnn",
    srcs = ["cudnn.cc"],
    hdrs = ["cudnn.h"],
    # NOTE(review): project-defined tags; presumably limit this target to
    # CUDA-enabled GPU builds -- confirm against the CI tag filters.
    tags = [
        "cuda-only",
        "gpu",
    ],
    deps = [
        ":gpu_codegen_backend",
        "//xla:autotuning_proto_cc",
        "//xla:util",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_query",
        "//xla/service:algorithm_util",
        "//xla/service:compiler",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:cublas_cudnn",
        "//xla/service/gpu:gpu_conv_runner",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:stream_executor_util",
        "//xla/service/gpu/transforms:cudnn_fusion_compiler",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:dnn",
        "//xla/stream_executor:numeric_options",
        "//xla/stream_executor:stream",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor:stream_executor_memory_allocator",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/protobuf:dnn_proto_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],
)

# Test for :cudnn, declared through the xla_test macro.
xla_test(
    name = "cudnn_test",
    srcs = ["cudnn_test.cc"],
    # Same device-specific backend list as the other backend tests in this
    # file.
    backends = [
        "a100",
        "h100",
        "b200",
    ],
    # NOTE(review): project-defined tags; presumably restrict the test to CUDA
    # builds and exclude macOS -- confirm against the CI tag filters.
    tags = [
        "cuda-only",
        "no_mac",
    ],
    deps = [
        ":cudnn",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:compiler",
        "//xla/service:platform_util",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:nvptx_compiler_impl",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/protobuf:dnn_proto_cc",
        "//xla/tsl/util/proto:proto_matchers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_googletest//:gtest_main",
    ],
)

# Library built from custom_kernel.{cc,h}. Used directly by :fission in this
# file.
# NOTE(review): no "gpu" tag here although the gpu-tagged :fission depends on
# this target -- confirm the omission is intentional.
cc_library(
    name = "custom_kernel",
    srcs = ["custom_kernel.cc"],
    hdrs = ["custom_kernel.h"],
    deps = [
        ":gpu_codegen_backend",
        "//xla:autotune_results_proto_cc",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/service:compiler",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu/kernels:custom_kernel",
        "//xla/service/gpu/kernels:custom_kernel_fusion",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

# Test for :custom_kernel, declared through the xla_test macro.
xla_test(
    name = "custom_kernel_test",
    srcs = ["custom_kernel_test.cc"],
    # Same device-specific backend list as the other backend tests in this
    # file.
    backends = [
        "a100",
        "h100",
        "b200",
    ],
    # NOTE(review): project-defined tags; presumably restrict the test to CUDA
    # builds and exclude macOS -- confirm against the CI tag filters.
    tags = [
        "cuda-only",
        "no_mac",
    ],
    deps = [
        ":custom_kernel",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:compiler",
        "//xla/service:platform_util",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:nvptx_compiler_impl",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:proto_matchers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_googletest//:gtest_main",
    ],
)

# Library built from fission.{cc,h}. It is the only backend library in this
# file that depends on other backend libraries (:cublas, :cublaslt and
# :custom_kernel).
cc_library(
    name = "fission",
    srcs = ["fission.cc"],
    hdrs = ["fission.h"],
    # NOTE(review): "gpu" is a project-defined tag; presumably used to filter
    # this target out of non-GPU builds -- confirm against the CI tag filters.
    tags = ["gpu"],
    deps = [
        ":cublas",
        ":cublaslt",
        ":custom_kernel",
        ":gpu_codegen_backend",
        "//xla:autotuning_proto_cc",
        "//xla:util",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_query",
        "//xla/service:call_inliner",
        "//xla/service:compiler",
        "//xla/service:hlo_cost_analysis",
        "//xla/service:hlo_module_config",
        "//xla/service:hlo_proto_cc",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:cublas_cudnn",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
        "//xla/service/gpu/transforms:dot_algorithm_rewriter",
        "//xla/service/gpu/transforms:gemm_rewriter",
        "//xla/service/gpu/transforms:priority_fusion",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tools:hlo_decomposer_lib",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],
)

# Test for :fission, declared through the xla_test macro.
xla_test(
    name = "fission_test",
    srcs = ["fission_test.cc"],
    # Same device-specific backend list as the other backend tests in this
    # file.
    backends = [
        "a100",
        "h100",
        "b200",
    ],
    # NOTE(review): project-defined tags; presumably restrict the test to CUDA
    # builds and exclude macOS -- confirm against the CI tag filters.
    tags = [
        "cuda-only",
        "no_mac",
    ],
    deps = [
        ":fission",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:compiler",
        "//xla/service:hlo_module_util",
        "//xla/service:platform_util",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:nvptx_compiler_impl",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
    ],
)

# Library built from triton.{cc,h}. Used directly by :factory in this file.
cc_library(
    name = "triton",
    srcs = ["triton.cc"],
    hdrs = ["triton.h"],
    # NOTE(review): "gpu" is a project-defined tag; presumably used to filter
    # this target out of non-GPU builds -- confirm against the CI tag filters.
    tags = ["gpu"],
    deps = [
        ":gpu_codegen_backend",
        "//xla:autotuning_proto_cc",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/transforms/simplifiers:float_normalization",
        "//xla/hlo/utils:hlo_query",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service:compiler",
        "//xla/service:hlo_cost_analysis",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_float_support",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:split_k_gemm_rewriter",
        "//xla/service/gpu/autotuning:dot_search_space",
        "//xla/service/gpu/transforms:fusion_wrapper",
        "//xla/service/gpu/transforms:nest_gemm_fusion",
        "//xla/service/gpu/transforms:priority_fusion",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
)

# Test for :triton, declared through the xla_test macro.
xla_test(
    name = "triton_test",
    srcs = ["triton_test.cc"],
    # Same device-specific backend list as the other backend tests in this
    # file.
    backends = [
        "a100",
        "h100",
        "b200",
    ],
    tags = [
        "cuda-only",  # ROCm support is not tested.
        "no_mac",
    ],
    deps = [
        ":triton",
        "//xla:autotuning_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:compiler",
        "//xla/service:executable",
        "//xla/service:platform_util",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:nvptx_compiler_impl",
        "//xla/stream_executor:device_description_proto_cc",
        "//xla/tsl/platform:status_matchers",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:proto_matchers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_googletest//:gtest_main",
    ],
)

# Header-only target (no `srcs`) exporting factory.h. Of the backend libraries
# in this file it depends only on :cublas and :triton; it is linked into
# :autotuner_main.
cc_library(
    name = "factory",
    hdrs = ["factory.h"],
    # NOTE(review): "gpu" is a project-defined tag; presumably used to filter
    # this target out of non-GPU builds -- confirm against the CI tag filters.
    tags = ["gpu"],
    deps = [
        ":cublas",
        ":triton",
        "//xla/backends/autotuner:codegen_backend",
        "//xla/service:compiler",
        "//xla/stream_executor:stream_executor_h",
    ],
)

# Library built from gpu_profiler.{cc,h}. Depends on
# //xla/backends/autotuner:profiler (presumably the interface it implements --
# confirm in gpu_profiler.h) and is linked into :autotuner_main.
cc_library(
    name = "gpu_profiler",
    srcs = ["gpu_profiler.cc"],
    hdrs = ["gpu_profiler.h"],
    deps = [
        "//xla:executable_run_options",
        "//xla:xla_data_proto_cc",
        "//xla/backends/autotuner:profiler",
        "//xla/hlo/ir:hlo",
        "//xla/service:executable",
        "//xla/service:maybe_owning_device_memory",
        "//xla/service/gpu:gpu_executable_run_options",
        "//xla/service/gpu/autotuning:redzone_buffers",
        "//xla/stream_executor:device_memory_allocator",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor:stream_executor_memory_allocator",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
    ],
)

# Test for :gpu_profiler, declared through the xla_test macro. Unlike the
# backend tests in this file it sets no `tags`.
xla_test(
    name = "gpu_profiler_test",
    srcs = ["gpu_profiler_test.cc"],
    # Uses the generic "gpu" backend name instead of the per-device list
    # (a100/h100/b200) used by the other tests in this file.
    backends = ["gpu"],
    deps = [
        ":gpu_profiler",
        "//xla:executable_run_options",
        "//xla:shape_util",
        "//xla/backends/autotuner:profiler",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:executable",
        "//xla/service:platform_util",
        "//xla/stream_executor:platform",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
    ],
)

# Binary built from autotuner_main.cc via the xla_cc_binary macro (loaded from
# //xla:xla.default.bzl). Links the local :factory and :gpu_profiler targets
# together with //xla/backends/autotuner.
# NOTE(review): //xla/stream_executor/cuda:cuda_platform_id is an unconditional
# dep while the target is tagged only "gpu" (not "cuda-only") -- confirm that
# non-CUDA GPU configurations are not expected to build this binary.
xla_cc_binary(
    name = "autotuner_main",
    srcs = ["autotuner_main.cc"],
    tags = ["gpu"],
    deps = [
        ":factory",
        ":gpu_profiler",
        "//xla:debug_options_flags",
        "//xla/backends/autotuner",
        "//xla/backends/autotuner:profiler",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/parser:hlo_parser",
        "//xla/service:compiler",
        "//xla/service:gpu_plugin",
        "//xla/service:platform_util",
        "//xla/stream_executor:platform",
        "//xla/stream_executor:platform_manager",
        "//xla/stream_executor/cuda:cuda_platform_id",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util:command_line_flags",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@local_tsl//tsl/platform:platform_port",
    ],
)
