load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_shared_test",
)

# Package-wide defaults for the MLRT kernel targets below.
#
# None of the targets in this file set `visibility` explicitly, so they all
# inherit `default_visibility`: the TFRT graph_executor, ifrt, saved_model and
# tfrt_session packages (and their subpackages).
#
# Lines prefixed with `# copybara:uncomment` are Copybara transform markers;
# presumably they are uncommented only in the internal (non-OSS) copy of this
# file. Keep them byte-identical.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = [
        # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__",
        # copybara:uncomment "//learning/brain/tfrt:__subpackages__",
        "//tensorflow/core/tfrt/graph_executor:__subpackages__",
        "//tensorflow/core/tfrt/ifrt:__subpackages__",
        "//tensorflow/core/tfrt/saved_model:__subpackages__",
        "//tensorflow/core/tfrt/tfrt_session:__subpackages__",
    ],
)

# Core MLRT kernel library (kernel.h / kernel.cc).
#
# Builds on the MLRT interpreter (context, execute, future, register_span,
# value, builtin_kernels, async_handle, attribute_span), the TF fallback
# OpKernelRunner and FallbackTensor, and TFRT's hostcontext.
#
# NOTE(review): from these deps it presumably implements the MLRT kernels that
# execute TensorFlow ops via the fallback path; confirm in kernel.cc.
#
# Shares `:kernel_runner_utils` and `:context` with `:batch_kernel`.
cc_library(
    name = "kernel",
    srcs = ["kernel.cc"],
    hdrs = ["kernel.h"],
    deps = [
        ":context",
        ":kernel_runner_utils",
        "//tensorflow/core:framework",
        "//tensorflow/core/framework:tensor_proto_cc",
        "//tensorflow/core/tfrt/fallback:op_kernel_runner",
        "//tensorflow/core/tfrt/mlrt/bytecode:function",
        "//tensorflow/core/tfrt/mlrt/interpreter:async_handle",
        "//tensorflow/core/tfrt/mlrt/interpreter:attribute_span",
        "//tensorflow/core/tfrt/mlrt/interpreter:builtin_kernels",
        "//tensorflow/core/tfrt/mlrt/interpreter:context",
        "//tensorflow/core/tfrt/mlrt/interpreter:execute",
        "//tensorflow/core/tfrt/mlrt/interpreter:future",
        "//tensorflow/core/tfrt/mlrt/interpreter:register_span",
        "//tensorflow/core/tfrt/mlrt/interpreter:value",
        "//tensorflow/core/tfrt/utils",
        "//tensorflow/core/tfrt/utils:fallback_tensor",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@local_tsl//tsl/profiler/lib:traceme",
        "@tf_runtime//:hostcontext",
    ],
)

# Standalone utility library (shard_restore_util.h / shard_restore_util.cc).
#
# It depends only on Abseil: no TensorFlow, MLRT or IFRT targets. That keeps
# it cheap to build and to unit-test (see `shard_restore_util_test`).
#
# NOTE(review): judging by the name, this presumably splits variable-restore
# work into shards; confirm against shard_restore_util.h.
cc_library(
    name = "shard_restore_util",
    srcs = ["shard_restore_util.cc"],
    hdrs = ["shard_restore_util.h"],
    deps = [
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

# MLRT kernels for IFRT-related ops (ifrt_ops_kernel.cc).
#
# Pulls in the IFRT model/restore contexts, checkpoint loader, loaded-variable
# utils and restore-tensor registry from //tensorflow/core/tfrt/ifrt, plus the
# XLA IFRT API.
#
# The library exposes no headers. Dependents therefore never reference its
# symbols directly, which is why `alwayslink` is set (see below).
cc_library(
    name = "ifrt_ops_kernel",
    srcs = ["ifrt_ops_kernel.cc"],
    deps = [
        ":context",
        ":kernel",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core/framework:attr_value_proto_cc",
        "//tensorflow/core/framework:types_proto_cc",
        "//tensorflow/core/platform:protobuf",
        "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
        "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc",
        "//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_utils",
        "//tensorflow/core/tfrt/ifrt:ifrt_model_context",
        "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
        "//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry",
        "//tensorflow/core/tfrt/mlrt/bytecode",
        "//tensorflow/core/tfrt/mlrt/interpreter:context",
        "//tensorflow/core/tfrt/mlrt/interpreter:future",
        "//tensorflow/core/tfrt/utils:fallback_tensor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@local_tsl//tsl/platform:errors",
        "@local_tsl//tsl/platform:tstring",
        "@local_xla//xla:xla_data_proto_cc",
        "@local_xla//xla/python/ifrt",
    ],
    # Force-link every object file from srcs even if no symbol in it is
    # referenced. Presumably the kernels register themselves via static
    # initializers, which a plain link would otherwise drop; confirm in
    # ifrt_ops_kernel.cc.
    alwayslink = 1,
)

# Batching kernel for MLRT (batch_kernel.h / batch_kernel.cc).
#
# Built on TensorFlow's batching utilities (BatchResourceBase, batch scheduler
# headers) and the runtime-fallback batch kernel. It also uses the
# OpKernelRunner cache and connected TraceMe for profiling.
#
# Shares `:kernel_runner_utils` and `:context` with `:kernel`.
cc_library(
    name = "batch_kernel",
    srcs = ["batch_kernel.cc"],
    hdrs = ["batch_kernel.h"],
    deps = [
        ":context",
        ":kernel_runner_utils",
        "//tensorflow/core:framework",
        "//tensorflow/core/kernels/batching_util:batch_resource_base",
        "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:statusor",
        "//tensorflow/core/runtime_fallback/runtime:fallback_batch_kernel",
        "//tensorflow/core/tfrt/fallback:op_kernel_runner_cache",
        "//tensorflow/core/tfrt/mlrt/bytecode:span",
        "//tensorflow/core/tfrt/mlrt/interpreter:context",
        "//tensorflow/core/tfrt/mlrt/interpreter:execute",
        "//tensorflow/core/tfrt/mlrt/interpreter:register_span",
        "//tensorflow/core/tfrt/utils:fallback_tensor",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf_headers",
        "@local_tsl//tsl/profiler/lib:connected_traceme",
        "@local_tsl//tsl/profiler/lib:context_types_hdrs",
        "@tf_runtime//:async_value",
        "@tf_runtime//:hostcontext",
    ],
)

# Shared helpers for invoking TensorFlow OpKernels from MLRT kernels.
# Files: kernel_runner_utils.h / kernel_runner_utils.cc.
#
# Consumed by both `:kernel` and `:batch_kernel`. Depends on the fallback
# OpKernelRunner, the kernel-fallback request state and utils, and the MLRT
# interpreter context/future/register_span.
#
# NOTE(review): this target uses //tensorflow/core/profiler/lib:traceme,
# whereas `:kernel` uses @local_tsl//tsl/profiler/lib:traceme. Consider
# unifying, but only after checking which header kernel_runner_utils.* include
# (layering_check).
cc_library(
    name = "kernel_runner_utils",
    srcs = ["kernel_runner_utils.cc"],
    hdrs = ["kernel_runner_utils.h"],
    deps = [
        ":context",
        "//tensorflow/core:framework",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_utils",
        "//tensorflow/core/tfrt/fallback:op_kernel_runner",
        "//tensorflow/core/tfrt/mlrt/interpreter:context",
        "//tensorflow/core/tfrt/mlrt/interpreter:future",
        "//tensorflow/core/tfrt/mlrt/interpreter:register_span",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/cleanup",
    ],
)

# Header-only library (context.h; no srcs) shared by every kernel target in
# this package.
#
# Depends on the kernel-fallback compat request state, the fallback
# OpKernelRunner, the MLRT interpreter context and TFRT hostcontext.
#
# NOTE(review): presumably defines the MLRT-side context that carries TF
# fallback request state into kernels; confirm in context.h.
cc_library(
    name = "context",
    hdrs = ["context.h"],
    deps = [
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
        "//tensorflow/core/tfrt/fallback:op_kernel_runner",
        "//tensorflow/core/tfrt/mlrt/interpreter:context",
        "@tf_runtime//:hostcontext",
    ],
)

# Unit tests for `:kernel` (kernel_test.cc). The test also links
# `:batch_kernel` and `:context`.
#
# Uses TF math kernels/ops and a custom-allocator fallback device.
# `tf_cc_shared_test` is the macro loaded from //tensorflow:tensorflow.bzl.
# `no_oss` presumably excludes this test from open-source CI runs.
#
# NOTE(review): both @local_tsl//tsl/platform:status_matchers and
# @local_xla//xla/tsl/platform:status_matchers are listed. One may be a
# forwarding alias of the other. Verify which header kernel_test.cc includes
# before dropping either.
tf_cc_shared_test(
    name = "kernel_test",
    srcs = [
        "kernel_test.cc",
    ],
    tags = ["no_oss"],
    deps = [
        ":batch_kernel",
        ":context",
        ":kernel",
        "//tensorflow/core:framework",
        "//tensorflow/core/framework:tensor_testutil",
        "//tensorflow/core/kernels:math",
        "//tensorflow/core/ops:math_ops_op_lib",
        "//tensorflow/core/tfrt/fallback:device_with_custom_allocator",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "//tensorflow/core/tfrt/mlrt/bytecode:executable",
        "//tensorflow/core/tfrt/mlrt/interpreter:execute",
        "//tensorflow/core/tfrt/mlrt/interpreter:future",
        "//tensorflow/core/tfrt/mlrt/interpreter:interpreter_testutil",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
        "@local_tsl//tsl/platform:status_matchers",
        "@local_xla//xla/tsl/lib/core:status_test_util",
        "@local_xla//xla/tsl/platform:status_matchers",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//:ref_count",
    ],
)

# Unit tests for `:batch_kernel` (batch_kernel_test.cc). The test also links
# `:kernel` and `:context`, plus TF math kernels/ops.
#
# NOTE(review): this is the only test target in this file without
# `tags = ["no_oss"]`. Confirm it is intended to run in open-source CI.
# NOTE(review): it lists both the @local_tsl and @local_xla status_matchers
# targets. One may be an alias; check includes before removing either.
tf_cc_shared_test(
    name = "batch_kernel_test",
    srcs = ["batch_kernel_test.cc"],
    deps = [
        ":batch_kernel",
        ":context",
        ":kernel",
        "//tensorflow/core:framework",
        "//tensorflow/core/framework:tensor_testutil",
        "//tensorflow/core/kernels:math",
        "//tensorflow/core/ops:math_ops_op_lib",
        "//tensorflow/core/tfrt/fallback:device_with_custom_allocator",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "//tensorflow/core/tfrt/mlrt/bytecode",
        "//tensorflow/core/tfrt/mlrt/bytecode:executable",
        "//tensorflow/core/tfrt/mlrt/interpreter:execute",
        "//tensorflow/core/tfrt/mlrt/interpreter:interpreter_testutil",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
        "@local_tsl//tsl/platform:status_matchers",
        "@local_xla//xla/tsl/lib/core:status_test_util",
        "@local_xla//xla/tsl/platform:status_matchers",
        "@tf_runtime//:hostcontext",
    ],
)

# Tests for `:ifrt_ops_kernel` (ifrt_ops_kernel_test.cc).
#
# Test fixtures come from the testdata target listed in `data`. The test runs
# against the TFRT CPU IFRT client test library
# (xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib) and uses a mock serving
# device selector.
#
# Linking `:ifrt_ops_kernel` as a direct dep pulls in its alwayslink'd
# objects. `no_oss` presumably excludes this test from open-source CI.
tf_cc_shared_test(
    name = "ifrt_ops_kernel_test",
    srcs = ["ifrt_ops_kernel_test.cc"],
    data = [
        "//tensorflow/core/tfrt/mlrt/kernel/testdata",
    ],
    tags = ["no_oss"],
    deps = [
        ":context",
        ":ifrt_ops_kernel",
        ":kernel",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core/framework:tensor",
        "//tensorflow/core/framework:tensor_matcher",
        "//tensorflow/core/framework:tensor_testutil",
        "//tensorflow/core/kernels:array",
        "//tensorflow/core/kernels:io",
        "//tensorflow/core/kernels:math",
        "//tensorflow/core/ops:math_ops_op_lib",
        "//tensorflow/core/platform:protobuf",
        "//tensorflow/core/platform:resource_loader",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "//tensorflow/core/tfrt/fallback:op_kernel_runner",
        "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
        "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc",
        "//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_registry",
        "//tensorflow/core/tfrt/ifrt:ifrt_model_context",
        "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
        "//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry",
        "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector",
        "//tensorflow/core/tfrt/mlrt/bytecode",
        "//tensorflow/core/tfrt/mlrt/bytecode:executable",
        "//tensorflow/core/tfrt/mlrt/interpreter:builtin_kernels",
        "//tensorflow/core/tfrt/mlrt/interpreter:context",
        "//tensorflow/core/tfrt/mlrt/interpreter:execute",
        "//tensorflow/core/tfrt/mlrt/interpreter:interpreter_testutil",
        "//tensorflow/core/tfrt/mlrt/interpreter:value",
        "//tensorflow/core/tfrt/utils:fallback_tensor",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
        "@eigen_archive//:eigen3",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:env",
        "@local_tsl//tsl/platform:refcount",
        "@local_tsl//tsl/platform:status",
        "@local_tsl//tsl/platform:status_matchers",
        "@local_tsl//tsl/platform:statusor",
        "@local_tsl//tsl/platform:tstring",
        "@local_xla//xla/python/ifrt",
        "@local_xla//xla/python/ifrt:test_util",
        "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib",
        "@local_xla//xla/tsl/framework:serving_device_selector",
        "@local_xla//xla/tsl/framework/test_util:mock_serving_device_selector",
        "@local_xla//xla/tsl/lib/core:status_test_util",
        "@tf_runtime//:hostcontext",
    ],
)

# Unit tests for `:shard_restore_util` (shard_restore_util_test.cc).
# Only the library under test, absl span and gtest_main are needed.
# `no_oss` presumably excludes this test from open-source CI runs.
tf_cc_shared_test(
    name = "shard_restore_util_test",
    srcs = ["shard_restore_util_test.cc"],
    tags = ["no_oss"],
    deps = [
        ":shard_restore_util",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
    ],
)
