load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
load("//xla:xla.default.bzl", "xla_cc_test")
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")

# Package-wide defaults. Every target below is visible to
# `//xla/python/ifrt:users` unless it sets its own `visibility`, and the
# package is declared under the "notice" license.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = ["//xla/python/ifrt:users"],
    licenses = ["notice"],
)

# Runs mlir-tblgen over passes.td with `-gen-pass-decls -name=IfrtIr` and
# produces passes.h.inc. The `passes` library depends on this target.
gentbl_cc_library(
    name = "passes_inc_gen",
    compatible_with = get_compatible_with_portable(),
    tbl_outs = {"passes.h.inc": [
        "-gen-pass-decls",
        "-name=IfrtIr",
    ]},
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes.td",
    deps = [
        # MLIR's pass-definition .td files; presumably included by passes.td.
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)

# IFRT IR pass library. It bundles:
#   - the pass implementation sources listed in `srcs` (by naming, roughly
#     one `*_pass.cc` per pass),
#   - passes.cc,
#   - the multi-threaded atom program compiler.
# Public headers are passes.h, map_ifrt_to_vifrt.h and
# multi_threaded_atom_program_compiler.h.
cc_library(
    name = "passes",
    srcs = [
        "ifrt_atom_programs_from_vhlo_pass.cc",
        "ifrt_atom_programs_to_vhlo_pass.cc",
        "ifrt_compile_and_propagate_shardings_pass.cc",
        "ifrt_compile_atom_program_pass.cc",
        "ifrt_dump_atom_programs_pass.cc",
        "ifrt_duplicated_callee_elimination_pass.cc",
        "ifrt_legalize_to_vifrt_pass.cc",
        "ifrt_lower_atom_program_metadata_to_xla_pass.cc",
        "ifrt_lower_mpmd_reshard_to_call_pass.cc",
        "ifrt_merge_reshards_pass.cc",
        "ifrt_outline_atom_program_to_module_pass.cc",
        "ifrt_populate_atom_program_metadata_pass.cc",
        "ifrt_precompile_atom_program_preprocessing_pass.cc",
        "ifrt_remove_attrs_from_other_dialects_pass.cc",
        "ifrt_remove_ifrt_attrs_pass.cc",
        "ifrt_reshard_to_copy_arrays_pass.cc",
        "ifrt_to_dot_pass.cc",
        "ifrt_verify_bound_external_loaded_executable_pass.cc",
        "ifrt_verify_device_type_consistency_pass.cc",
        "ifrt_verify_donation_pass.cc",
        "ifrt_verify_sharding_specified_pass.cc",
        "multi_threaded_atom_program_compiler.cc",
        "passes.cc",
        "spmd_expandable_interface_verification_pass.cc",
        "spmd_expansion_pass.cc",
        "vifrt_legalize_to_ifrt_pass.cc",
        "vifrt_to_version_pass.cc",
    ],
    hdrs = [
        "map_ifrt_to_vifrt.h",
        "multi_threaded_atom_program_compiler.h",
        "passes.h",
    ],
    compatible_with = get_compatible_with_portable(),
    deps = [
        ":debug",
        # Generated passes.h.inc (tblgen pass declarations).
        ":passes_inc_gen",
        ":utils",
        "//xla:status_macros",
        "//xla:xla_data_proto_cc",
        "//xla/client:executable_build_options",
        "//xla/hlo/ir:hlo",
        "//xla/mlir_hlo",
        "//xla/pjrt:pjrt_compiler",
        "//xla/pjrt:pjrt_executable",
        "//xla/python/ifrt",
        "//xla/python/ifrt:attribute_map",
        "//xla/python/ifrt/hlo:hlo_program",
        "//xla/python/ifrt/ir",
        "//xla/python/ifrt/ir:atom_program_compiler",
        "//xla/python/ifrt/ir:ifrt_ir_program",
        "//xla/python/ifrt/ir:ifrt_ir_program_proto_cc",
        "//xla/python/ifrt/ir:ifrt_ops_inc_gen",
        "//xla/python/ifrt/ir:sharding_param",
        "//xla/python/ifrt/ir:version",
        "//xla/python/ifrt/ir:vifrt",
        "//xla/python/ifrt/support:sharding_conversions",
        "//xla/python/pjrt_ifrt:xla_ifrt",
        "//xla/service:compilation_environments",
        "//xla/service:computation_placer_hdr",
        "//xla/service:hlo_proto_cc",
        "//xla/service/spmd/shardy:constants",
        "//xla/service/spmd/shardy:utils",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:status",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/container:btree",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        # NOTE(review): AllPassesAndDialects is a large umbrella target; confirm
        # it is still required rather than narrower MLIR targets.
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Rewrite",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
        "@llvm-project//mlir:Transforms",
        "@local_tsl//tsl/platform:fingerprint",
        "@local_tsl//tsl/platform:numbers",
        "@local_tsl//tsl/platform:path",
        "@local_tsl//tsl/platform:protobuf",
        "@shardy//shardy/dialect/sdy/ir:dialect",
        "@stablehlo//:register",
        "@stablehlo//:stablehlo_ops",
        "@stablehlo//:stablehlo_serialization",
    ],
)

# Library exposing built_in_spmd_expansions.h. It is built on the SPMD expander
# library in spmd_expanders/ and on the MLIR func dialect.
cc_library(
    name = "built_in_spmd_expansions",
    srcs = ["built_in_spmd_expansions.cc"],
    hdrs = ["built_in_spmd_expansions.h"],
    compatible_with = get_compatible_with_portable(),
    deps = [
        "//xla/python/ifrt/ir/transforms/spmd_expanders:spmd_expander",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
    ],
)

# Helper library (utils.h) that the `passes` library depends on. It depends on
# the IFRT IR dialect, the XLA MLIR type utilities and pjrt_ifrt:pjrt_dtype.
cc_library(
    name = "utils",
    srcs = ["utils.cc"],
    hdrs = ["utils.h"],
    compatible_with = get_compatible_with_portable(),
    deps = [
        "//xla:xla_data_proto_cc",
        "//xla/mlir/utils:type_util",
        "//xla/python/ifrt",
        "//xla/python/ifrt/ir",
        "//xla/python/pjrt_ifrt:pjrt_dtype",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:fingerprint",
    ],
)

# Debugging helpers (debug.h). Both the `passes` library and `debug_test`
# depend on this target. It has no dependency on the IFRT IR dialect targets;
# it uses only tsl/absl and core MLIR/LLVM libraries.
cc_library(
    name = "debug",
    srcs = ["debug.cc"],
    hdrs = ["debug.h"],
    compatible_with = get_compatible_with_portable(),
    deps = [
        "//xla/tsl/platform:env",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:path",
        "@local_tsl//tsl/platform:random",
        "@local_tsl//tsl/platform:regexp",
    ],
)

# Unit tests for the `debug` library (debug_test.cc), linked with gtest_main.
cc_test_placeholder_do_not_use = None
