load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Package-wide defaults. Any target below that sets no `visibility` of its own is
# visible only to //tensorflow/core/tfrt/saved_model. `:__pkg__` grants access to
# that exact package, not to its subpackages.
# The commented-out `copybara:uncomment` line is a directive for the Copybara
# export tooling. Presumably it is uncommented in the exported tree, so keep its
# text exact.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = ["//tensorflow/core/tfrt/saved_model:__pkg__"],
    licenses = ["notice"],
)

# Kernel library built from gpurt_kernels.cc on top of :gpu_runner.
# It exports no headers (`hdrs` is unset). It has no explicit `visibility`, so
# it is reachable only through the package's default_visibility.
cc_library(
    name = "gpurt_kernels",
    srcs = ["gpurt_kernels.cc"],
    deps = [
        ":gpu_runner",
        "//tensorflow/core:framework",
        "//tensorflow/core/framework:tensor",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
        "//tensorflow/core/runtime_fallback/kernel:tensor_util",
        "//tensorflow/core/tfrt/utils:fallback_tensor",
        "//tensorflow/core/tfrt/utils:gpu_variables_table",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//:support",
        "@tf_runtime//:tensor_alwayslink",
    ],
    # Forces every object file from `srcs` into any binary that depends on this
    # library, even when the binary references none of its symbols. Presumably
    # this is needed because gpurt_kernels.cc registers kernels through static
    # initializers. TODO confirm.
    alwayslink = True,
)

# Library exposing the GPU runner API declared in gpu_runner.h and implemented
# in gpu_runner.cc. It depends on the XLA/PjRt compile and launch utilities from
# //tensorflow/compiler/jit and on the TSL serving device selector.
# Used by :gpurt_kernels, :gpu_runner_test and :tfrt_gpu_init in this package.
cc_library(
    name = "gpu_runner",
    srcs = ["gpu_runner.cc"],
    hdrs = ["gpu_runner.h"],
    visibility = ["//visibility:public"],  # Overrides the package default_visibility.
    deps = [
        "//tensorflow/compiler/jit:pjrt_compile_util",
        "//tensorflow/compiler/jit:pjrt_tensor_buffer_util",
        "//tensorflow/compiler/jit:xla_compile_util",
        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
        "//tensorflow/compiler/jit:xla_launch_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/core:framework",
        "//tensorflow/core/framework:attr_value_proto_cc",
        "//tensorflow/core/framework:function_proto_cc",
        "//tensorflow/core/framework:types_proto_cc",
        "//tensorflow/core/platform:notification",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
        "//tensorflow/core/tfrt/common:global_state",
        "//tensorflow/core/tfrt/utils:fallback_tensor",
        "//tensorflow/core/tfrt/utils:gpu_variables_table",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@local_tsl//tsl/platform:fingerprint",
        "@local_tsl//tsl/platform:protobuf",
        "@local_xla//xla/pjrt:pjrt_client",
        "@local_xla//xla/pjrt:pjrt_common",
        "@local_xla//xla/tsl/framework:device_id",
        "@local_xla//xla/tsl/framework:serving_device_selector",
        "@local_xla//xla/tsl/platform:errors",
        "@local_xla//xla/tsl/platform:statusor",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//:support",
        "@tf_runtime//:tensor_alwayslink",
    ],
)

# CUDA test for :gpu_runner (gpu_runner_test.cc). It links gtest_main, so the
# source file does not define main().
# The `no*` tags follow TensorFlow's test-tag conventions. They presumably
# exclude this test from sanitizer (ASan/MSan/TSan) runs and from the
# configurations the `noopt` tag targets. Confirm against the TF CI tag handling.
tf_cuda_cc_test(
    name = "gpu_runner_test",
    srcs = ["gpu_runner_test.cc"],
    tags = [
        "gpu",  # Only enables test on GPU.
        "no_oss",  # Excluded from OSS builds: this test only runs with GPU.
        "noasan",
        "nomsan",
        "noopt",
        "notsan",
    ],
    deps = [
        ":gpu_runner",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:math_ops",
        "//tensorflow/cc:scope",
        "//tensorflow/compiler/jit:xla_gpu_jit",
        "//tensorflow/core:framework",
        "//tensorflow/core:test",
        "//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector",
        "//tensorflow/core/framework:tensor_testutil",
        "//tensorflow/core/kernels:ops_testutil",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:status_matchers",
        "//tensorflow/core/platform:statusor",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
        "//tensorflow/core/tfrt/common:pjrt_util",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "@com_google_googletest//:gtest_main",
        "@local_xla//xla/tsl/framework:serving_device_selector_policies",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//:tensor",
    ],
)

# GPU initialization library for the TFRT runtime, declared in tfrt_gpu_init.h
# and implemented in tfrt_gpu_init.cc. It builds on :gpu_runner, the GPU serving
# device selector, and the TSL serving-device-selector policies.
cc_library(
    name = "tfrt_gpu_init",
    srcs = ["tfrt_gpu_init.cc"],
    hdrs = ["tfrt_gpu_init.h"],
    visibility = ["//visibility:public"],  # Overrides the package default_visibility.
    deps = [
        ":gpu_runner",
        "//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/tfrt/runtime",
        "@com_google_absl//absl/status",
        "@local_xla//xla/tsl/framework:serving_device_selector_policies",
        "@tf_runtime//:hostcontext",
    ],
)
