load("//tensorflow:tensorflow.bzl", "if_google")

# Package-wide defaults: any target below that does not set its own
# `visibility` is visible only to the `:friends` package group.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = [":friends"],
    licenses = ["notice"],
)

# Packages that may depend on targets in this package which rely on the
# package-level `default_visibility`.
package_group(
    name = "friends",
    packages = [
        # Authorized users go here.
        # copybara:uncomment "//cloud/ai/platform/prediction/...",
        # copybara:uncomment "//learning/brain/experimental/tfrt/...",
        # copybara:uncomment "//learning/brain/tfrt/...",
        # copybara:uncomment "//learning/infra/mira/...",
        # copybara:uncomment "//learning/serving/...",
        # copybara:uncomment "//learning/pathways/serving/model_tests/...",
        # copybara:uncomment "//learning/pathways/serving/runtime/...",
        "//tensorflow/core/runtime_fallback/...",
        "//tensorflow/core/tfrt/mlrt/application/tensorflow/tests/...",
        "//tensorflow/compiler/mlir/tfrt/...",
        "//tensorflow/core/tfrt/...",
        "//tensorflow_serving/...",
        # Python wrapping. The `/...` form covers the `saved_model/python`
        # package itself as well as its subpackages, so no separate
        # exact-package entry is needed.
        "//tensorflow/core/tfrt/saved_model/python/...",
        # copybara:uncomment "//platforms/xla/tests/saved_models/...",
        # copybara:uncomment "//quality/webanswers/servo2/...",
        "//tensorflow/core/tfrt/saved_model/utils/...",
        "//smartass/brain/inference/...",
    ],
)

# Library built from saved_model_aot_compile.{h,cc}. Going by the target name
# it provides ahead-of-time compilation for saved models; its deps pull in the
# XLA GPU PJRT client/compiler and the CUDA platform id.
cc_library(
    name = "saved_model_aot_compile",
    srcs = ["saved_model_aot_compile.cc"],
    hdrs = ["saved_model_aot_compile.h"],
    deps = [
        ":saved_model_util",
        "//tensorflow/cc/saved_model:constants",
        "//tensorflow/compiler/jit:device_compilation_cluster_signature",
        "//tensorflow/compiler/jit:pjrt_device_compiler_client",
        "//tensorflow/compiler/jit:tf_graph_to_hlo_compiler",
        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
        "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
        "//tensorflow/compiler/mlir/tfrt:import_model",
        "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options",
        "//tensorflow/compiler/mlir/tfrt:tfrt_pipeline_options",
        "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:import_model",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/framework:attr_value_proto_cc",
        "//tensorflow/core/framework:function_proto_cc",
        "//tensorflow/core/framework:versions_proto_cc",
        "//tensorflow/core/ops",
        "//tensorflow/core/platform:enable_tf2_utils",
        "//tensorflow/core/platform:path",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/protobuf:for_core_protos_cc",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "//tensorflow/core/tfrt/graph_executor",
        "//tensorflow/core/tfrt/graph_executor:export_mlir",
        "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
        "//tensorflow/core/tfrt/mlrt/bytecode",
        "//tensorflow/core/tfrt/runtime",
        "//tensorflow/core/tfrt/saved_model/utils:serialize_utils",
        "//tensorflow/core/tfrt/utils",
        "//tensorflow/core/tpu:virtual_device",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@local_tsl//tsl/platform:env",
        "@local_tsl//tsl/platform:status",
        "@local_xla//xla/pjrt:pjrt_compiler",
        "@local_xla//xla/pjrt:pjrt_executable",
        "@local_xla//xla/pjrt/gpu:se_gpu_pjrt_client",
        "@local_xla//xla/pjrt/gpu:se_gpu_pjrt_compiler",
        "@local_xla//xla/service:compiler",
        "@local_xla//xla/stream_executor/cuda:cuda_platform_id",
        "@tf_runtime//:bef",
        "@tf_runtime//:befexecutor",
        "@tf_runtime//:core_runtime_alwayslink",
        "@tf_runtime//:hostcontext",
    ],
)

# Private implementation library compiling saved_model.cc. The header
# saved_model.h is listed in `srcs` (not `hdrs`) and the target is
# `//visibility:private`, so it is only reachable from targets in this package;
# :saved_model and :saved_model_cpu depend on it and list saved_model.h in
# their own `hdrs`.
cc_library(
    name = "saved_model_lib",
    srcs = [
        "saved_model.cc",
        "saved_model.h",
    ],
    visibility = ["//visibility:private"],
    deps = [
        ":saved_model_util",
        "//tensorflow/cc/saved_model:fingerprinting",
        "//tensorflow/cc/saved_model:reader",
        "//tensorflow/compiler/jit:flags_headers",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
        "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
        "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor",
        "//tensorflow/compiler/mlir/tfrt:import_model",
        "//tensorflow/compiler/mlir/tfrt:saved_model",
        "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options",
        "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:import_model",
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/common_runtime:device_mgr",
        "//tensorflow/core/framework:function_proto_cc",
        "//tensorflow/core/framework:graph_proto_cc",
        "//tensorflow/core/framework:tensor_proto_cc",
        "//tensorflow/core/ops",
        "//tensorflow/core/platform:enable_tf2_utils",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:path",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/protobuf:for_core_protos_cc",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "//tensorflow/core/tfrt/fallback:op_kernel_runner",
        "//tensorflow/core/tfrt/graph_executor",
        "//tensorflow/core/tfrt/graph_executor:export_mlir",
        "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
        "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
        "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
        "//tensorflow/core/tfrt/mlrt/bytecode",
        "//tensorflow/core/tfrt/mlrt/bytecode:executable",
        "//tensorflow/core/tfrt/mlrt/interpreter:context",
        "//tensorflow/core/tfrt/mlrt/kernel",
        "//tensorflow/core/tfrt/mlrt/kernel:batch_kernel",
        "//tensorflow/core/tfrt/runtime",
        "//tensorflow/core/tfrt/runtime:work_queue_interface",
        "//tensorflow/core/tfrt/saved_model/utils:serialize_utils",
        "//tensorflow/core/tfrt/stubs:model_config_stub",
        "//tensorflow/core/tfrt/utils",
        "//tensorflow/core/tfrt/utils:error_util",
        "//tensorflow/core/tfrt/utils:fallback_tensor",
        "//tensorflow/core/tfrt/utils:tfrt_graph_execution_state",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@llvm-project//mlir:FuncExtensions",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:protobuf",
        "@local_xla//xla:status_macros",
        "@tf_runtime//:bef",
        "@tf_runtime//:befexecutor",
        "@tf_runtime//:core_runtime_alwayslink",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//:init_tfrt_dialects",
        "@tf_runtime//:metrics",
        "@tf_runtime//:support",
    # NOTE(review): `if_google` comes from //tensorflow:tensorflow.bzl and
    # presumably adds these deps only in Google-internal builds; confirm there.
    ] + if_google([
        "//third_party/tf_runtime_google:streamz_metrics_registry_alwayslink",
    ]),
)

# Public, header-only wrapper exposing saved_model.h on top of
# :saved_model_lib. Unlike :saved_model, it does not depend on the GPU runtime
# kernels, the kernel-fallback tensor conversion, or the runtime-fallback
# alwayslink targets, and it has no `if_google` additions.
cc_library(
    name = "saved_model_cpu",
    hdrs = ["saved_model.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":saved_model_lib",
        ":saved_model_util",
        "//tensorflow/core/framework:graph_proto_cc",
        "//tensorflow/core/framework:tensor_proto_cc",
        "//tensorflow/core/platform:thread_annotations",
        "//tensorflow/core/protobuf:for_core_protos_cc",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "//tensorflow/core/tfrt/graph_executor",
        "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
        "//tensorflow/core/tfrt/runtime",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@local_tsl//tsl/platform:protobuf",
        "@tf_runtime//:hostcontext",
    ],
)

# Public, header-only wrapper exposing saved_model.h on top of
# :saved_model_lib. Compared with :saved_model_cpu it additionally links the
# kernel-fallback tensor conversion, the GPU runtime kernels, and the
# runtime-fallback alwayslink targets, plus extra `if_google` deps.
# NOTE(review): the deps below are not in sorted order; the order is left
# untouched here in case link order matters for the alwayslink targets.
cc_library(
    name = "saved_model",
    hdrs = ["saved_model.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":saved_model_lib",
        ":saved_model_util",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "//tensorflow/core/framework:graph_proto_cc",
        "//tensorflow/core/framework:tensor_proto_cc",
        "//tensorflow/core/platform:thread_annotations",
        "//tensorflow/core/protobuf:for_core_protos_cc",
        "@local_tsl//tsl/platform:protobuf",
        # TODO(chky): Remove kernel fallback tensor deps.
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_tensor_conversion_alwayslink",
        "//tensorflow/core/tfrt/gpu/kernel:gpurt_kernels",
        "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "//tensorflow/core/tfrt/graph_executor",
        "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
        "//tensorflow/core/tfrt/runtime",
        "@tf_runtime//:hostcontext",
    # NOTE(review): `if_google` comes from //tensorflow:tensorflow.bzl and
    # presumably adds these deps only in Google-internal builds; confirm there.
    ] + if_google([
        "//learning/brain/tfrt/support:export_mlir",
        "//learning/brain/tfrt/tpu/compiler/mlir:tf_to_tfrt_tpu",
        "//learning/brain/tfrt/saved_model:model_config_impl",
    ]),
)

# Test-only helpers (saved_model_testutil.{h,cc}) built on top of
# :saved_model. Tagged "no_oss"; presumably this keeps it out of open-source
# builds (confirm against the project's tag conventions).
cc_library(
    name = "saved_model_testutil",
    testonly = True,
    srcs = ["saved_model_testutil.cc"],
    hdrs = ["saved_model_testutil.h"],
    tags = ["no_oss"],
    deps = [
        ":saved_model",
        "//tensorflow/cc/saved_model:loader",
        "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/framework:tensor_testutil",
        "//tensorflow/core/tfrt/runtime",
        "@tf_runtime//:hostcontext",
    ],
)

# Library built from saved_model_import_input.{h,cc}; a dependency of
# :saved_model_util. Uses the package's default visibility.
cc_library(
    name = "saved_model_import_input",
    srcs = ["saved_model_import_input.cc"],
    hdrs = ["saved_model_import_input.h"],
    deps = [
        "//tensorflow/compiler/mlir/tensorflow:import_model",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "//tensorflow/core/tfrt/graph_executor:config",
        "//tensorflow/core/tfrt/utils:tfrt_graph_execution_state",
        "@com_google_absl//absl/synchronization",
    ],
)

# Shared utilities (saved_model_util.{h,cc}) used by :saved_model_lib,
# :saved_model_aot_compile, and the public :saved_model / :saved_model_cpu
# wrappers. Uses the package's default visibility.
cc_library(
    name = "saved_model_util",
    srcs = ["saved_model_util.cc"],
    hdrs = ["saved_model_util.h"],
    deps = [
        ":saved_model_import_input",
        "//tensorflow/cc/saved_model:reader",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:import_model",
        "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
        "//tensorflow/compiler/mlir/tfrt:import_model",
        "//tensorflow/compiler/mlir/tfrt:saved_model",
        "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options",
        # The target name here is "transforms/gpu_passes" (it contains a
        # slash); the package is //tensorflow/compiler/mlir/tfrt.
        "//tensorflow/compiler/mlir/tfrt:transforms/gpu_passes",
        "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs",
        "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_types_hdr",
        "//tensorflow/core:lib",
        "//tensorflow/core/framework:graph_proto_cc",
        "//tensorflow/core/framework:tensor",
        "//tensorflow/core/framework:tensor_proto_cc",
        "//tensorflow/core/platform:path",
        "//tensorflow/core/protobuf:for_core_protos_cc",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "//tensorflow/core/tfrt/graph_executor:config",
        "//tensorflow/core/tfrt/mlrt/bytecode",
        "//tensorflow/core/tfrt/saved_model/utils:serialize_utils",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:protobuf",
        "@tf_runtime//:bef",
        "@tf_runtime//:init_tfrt_dialects",
    ],
)
