load("//xla:xla.default.bzl", "xla_cc_test")
load("//xla/tests:build_defs.bzl", "xla_test")

# Package-wide defaults. Any target below that does not set its own
# `visibility` is visible only to the ":friends" package group; targets that
# do set `visibility` (see ":fusion_emitter" and ":fusions") override this.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = [":friends"],
    licenses = ["notice"],
)

# Package group used as this package's `default_visibility`. It lists no
# packages of its own; its membership comes entirely from the included
# "//xla:friends" group.
package_group(
    name = "friends",
    includes = [
        "//xla:friends",
    ],
)

# C++ library built from copy.cc / copy.h. Builds on ":fusion_emitter" and
# links against the GPU runtime copy thunk target. It is consumed in this
# file by ":copy_test" and ":fusions".
# NOTE(review): presumably the emitter for copy fusions - confirm against copy.h.
cc_library(
    name = "copy",
    srcs = ["copy.cc"],
    hdrs = ["copy.h"],
    deps = [
        ":fusion_emitter",
        "//xla:literal_util",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/backends/gpu/runtime:copy_thunk",
        "//xla/backends/gpu/runtime:thunk",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service:buffer_assignment",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter_context",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Test for ":copy", declared with the `xla_cc_test` macro loaded from
# //xla:xla.default.bzl. Unlike the `xla_test` targets in this file it sets
# no `backends` attribute.
xla_cc_test(
    name = "copy_test",
    srcs = ["copy_test.cc"],
    deps = [
        ":copy",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "@com_google_googletest//:gtest_main",
        "@local_tsl//tsl/platform:path",
    ],
)

# C++ library built from cudnn.cc / cudnn.h. Builds on ":fusion_emitter" and
# links against the GPU runtime cuDNN thunk target. It is consumed in this
# file by ":fusions".
# NOTE(review): presumably the emitter for cuDNN fusions - confirm against cudnn.h.
cc_library(
    name = "cudnn",
    srcs = ["cudnn.cc"],
    hdrs = ["cudnn.h"],
    deps = [
        ":fusion_emitter",
        "//xla/backends/gpu/runtime:cudnn_thunk",
        "//xla/backends/gpu/runtime:thunk",
        "//xla/codegen/emitters:kernel_arguments",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:gpu_constants",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:kernel_reuse_cache",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status:statusor",
    ],
)

# Test built from cudnn_test.cc with the `xla_test` macro loaded from
# //xla/tests:build_defs.bzl. It is restricted to the "gpu" backend and
# carries the "cuda-only" tag.
# Note that ":cudnn" is not a direct dependency of this target.
# NOTE(review): presumably the library is reached transitively through the
# GPU codegen test deps - confirm before pruning or adding deps.
xla_test(
    name = "cudnn_test",
    srcs = ["cudnn_test.cc"],
    backends = ["gpu"],
    tags = ["cuda-only"],
    deps = [
        "//xla:comparison_util",
        "//xla:debug_options_flags",
        "//xla:error_spec",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:pattern_matcher_gmock",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service:dump",
        "//xla/service:executable",
        "//xla/service:pattern_matcher",
        "//xla/service/gpu:cudnn_support_utils",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:stream_executor_util",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/service/gpu/transforms:cudnn_fusion_compiler",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:dnn",
        "//xla/stream_executor:stream_executor_h",
        "//xla/stream_executor:stream_executor_memory_allocator",
        "//xla/tsl/lib/core:status_test_util",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
        "@local_tsl//tsl/platform:env",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:path",
        "@local_tsl//tsl/platform:statusor",
        "@local_tsl//tsl/platform:test",
    ],
)

# C++ library built from custom.cc / custom.h. Builds on ":fusion_emitter"
# and depends on a broad set of GPU runtime thunk targets (all-reduce,
# collective, copy, custom-call, dynamic-slice, gemm, kernel), the FFI
# targets, and the custom kernel targets under //xla/service/gpu/kernels.
# It is consumed in this file by ":fusions".
# NOTE(review): presumably the emitter for custom / dynamic-slice fusions -
# confirm against custom.h.
cc_library(
    name = "custom",
    srcs = ["custom.cc"],
    hdrs = ["custom.h"],
    deps = [
        ":fusion_emitter",
        "//xla:literal",
        "//xla:literal_util",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla/backends/gpu/collectives:gpu_clique_key",
        "//xla/backends/gpu/runtime:all_reduce_thunk",
        "//xla/backends/gpu/runtime:collective_thunk",
        "//xla/backends/gpu/runtime:copy_thunk",
        "//xla/backends/gpu/runtime:custom_call_target",
        "//xla/backends/gpu/runtime:custom_call_thunk",
        "//xla/backends/gpu/runtime:dynamic_slice_thunk",
        "//xla/backends/gpu/runtime:gemm_thunk",
        "//xla/backends/gpu/runtime:kernel_thunk",
        "//xla/backends/gpu/runtime:thunk",
        "//xla/codegen/emitters:kernel_arguments",
        "//xla/ffi:attribute_map",
        "//xla/ffi:ffi_api",
        "//xla/hlo/analysis:while_loop_analysis",
        "//xla/hlo/evaluator:hlo_evaluator",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service:buffer_assignment",
        "//xla/service:custom_call_status",
        "//xla/service:custom_call_target_registry",
        "//xla/service:hlo_proto_cc",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:cublas_cudnn",
        "//xla/service/gpu:gpu_constants",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu:stream_executor_util",
        "//xla/service/gpu/kernels:custom_kernel",
        "//xla/service/gpu/kernels:custom_kernel_fusion",
        "//xla/stream_executor:stream",
        "//xla/tools:hlo_extractor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AsmParser",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:statusor",
    ],
)

# Test built from dynamic_slice_fusion_test.cc with the `xla_test` macro
# loaded from //xla/tests:build_defs.bzl. It is restricted to the "gpu"
# backend. Note that ":custom" is not a direct dependency; the test depends
# on the dynamic-slice thunk and the dynamic_slice_fusion_rewriter targets.
xla_test(
    name = "dynamic_slice_fusion_test",
    srcs = ["dynamic_slice_fusion_test.cc"],
    # Extra tags listed under the "gpu" backend key.
    # NOTE(review): presumably "multi_gpu" requests a multi-device runner and
    # "no_oss" excludes the test from open-source CI - confirm against the
    # xla_test macro and the CI tag filters.
    backend_tags = {
        "gpu": [
            "multi_gpu",
            "no_oss",
        ],
    },
    backends = ["gpu"],
    deps = [
        "//xla:error_spec",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/gpu/runtime:dynamic_slice_thunk",
        "//xla/backends/gpu/runtime:sequential_thunk",
        "//xla/backends/gpu/runtime:thunk",
        "//xla/backends/gpu/runtime:while_thunk",
        "//xla/ffi",
        "//xla/ffi:ffi_api",
        "//xla/hlo/builder:xla_builder",
        "//xla/hlo/builder/lib:constants",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/service:executable",
        "//xla/service:hlo_module_config",
        "//xla/service:hlo_proto_cc",
        "//xla/service:hlo_runner_interface",
        "//xla/service:platform_util",
        "//xla/service/gpu:gpu_executable",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:stream",
        "//xla/tests:hlo_test_base",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)

# C++ library built from fusion_emitter.cc / fusion_emitter.h. Every other
# cc_library in this file (":copy", ":cudnn", ":custom", ":fusions") depends
# on it.
cc_library(
    name = "fusion_emitter",
    srcs = ["fusion_emitter.cc"],
    hdrs = ["fusion_emitter.h"],
    # Overrides the package's default ":friends" visibility: only the two
    # package trees below may depend on this target.
    visibility = [
        "//xla/backends/gpu/codegen:__subpackages__",
        "//xla/service/gpu:__subpackages__",
    ],
    deps = [
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla/backends/gpu/runtime:thunk",
        "//xla/codegen/emitters:kernel_api_builder",
        "//xla/codegen/emitters:kernel_arguments",
        "//xla/hlo/analysis:indexing_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/runtime:work_dimensions",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu:target_util",
        "//xla/service/llvm_ir:ir_array",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/stream_executor:device_description",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:TargetParser",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
    ],
)

# C++ library built from fusions.cc / fusions.h. Aggregates the emitter
# libraries defined in this file (":copy", ":cudnn", ":custom") together with
# the emitters under //xla/backends/gpu/codegen/emitters and the Triton
# fusion target.
# NOTE(review): presumably the entry point that picks an emitter for a given
# fusion - confirm against fusions.h.
cc_library(
    name = "fusions",
    srcs = ["fusions.cc"],
    hdrs = ["fusions.h"],
    # Overrides the package's default ":friends" visibility: only the two
    # package trees below may depend on this target.
    visibility = [
        "//xla/backends/gpu/codegen:__subpackages__",
        "//xla/service/gpu:__subpackages__",
    ],
    deps = [
        ":copy",
        ":cudnn",
        ":custom",
        ":fusion_emitter",
        "//xla:shape_util",
        "//xla/backends/gpu/codegen/emitters:concatenate",
        "//xla/backends/gpu/codegen/emitters:in_place_dynamic_update_slice",
        "//xla/backends/gpu/codegen/emitters:input_slices",
        "//xla/backends/gpu/codegen/emitters:loop",
        "//xla/backends/gpu/codegen/emitters:reduction",
        "//xla/backends/gpu/codegen/emitters:scatter",
        "//xla/backends/gpu/codegen/emitters:transpose",
        "//xla/backends/gpu/codegen/triton:fusion",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service:buffer_assignment",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)
