Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@ workspace(name = "org_tensorflow_text")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# Toolchains for ML projects hermetic builds.
# Details: https://github.com/google-ml-infra/rules_ml_toolchain
http_archive(
name = "rules_ml_toolchain",
sha256 = "de3b14418657eeacd8afc2aa89608be6ec8d66cd6a5de81c4f693e77bc41bee1",
strip_prefix = "rules_ml_toolchain-5653e5a0ca87c1272069b4b24864e55ce7f129a1",
urls = [
"https://github.com/google-ml-infra/rules_ml_toolchain/archive/5653e5a0ca87c1272069b4b24864e55ce7f129a1.tar.gz",
],
)

load(
"@rules_ml_toolchain//cc_toolchain/deps:cc_toolchain_deps.bzl",
"cc_toolchain_deps",
)

cc_toolchain_deps()

register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64")

register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64_cuda")

http_archive(
name = "icu",
strip_prefix = "icu-release-64-2",
Expand Down Expand Up @@ -56,10 +78,10 @@ http_archive(

http_archive(
name = "org_tensorflow",
strip_prefix = "tensorflow-40998f44c0c500ce0f6e3b1658dfbc54f838a82a",
sha256 = "5a5bc4599964c71277dcac0d687435291e5810d2ac2f6283cc96736febf73aaf",
sha256 = "1a25308b15036bf8006ada5c9955ddc9a217792e6fc24deee04626ec07013f2c",
strip_prefix = "tensorflow-72fbba3d20f4616d7312b5e2b7f79daf6e82f2fa",
urls = [
"https://github.com/tensorflow/tensorflow/archive/40998f44c0c500ce0f6e3b1658dfbc54f838a82a.zip"
"https://github.com/tensorflow/tensorflow/archive/72fbba3d20f4616d7312b5e2b7f79daf6e82f2fa.zip",
],
)

Expand Down Expand Up @@ -134,6 +156,14 @@ load("@pypi//:requirements.bzl", "install_deps")

install_deps()

load("//oss_scripts/pip_package:tensorflow_text_python_wheel.bzl", "tensorflow_text_python_wheel_repository")

tensorflow_text_python_wheel_repository(
name = "tensorflow_text_wheel",
version_key = "__version__",
version_source = "//tensorflow_text:__init__.py",
)

# Initialize TensorFlow dependencies.
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
tf_workspace3()
Expand All @@ -151,14 +181,16 @@ load("@local_config_android//:android.bzl", "android_workspace")
android_workspace()

load(
"@local_xla//third_party/py:python_wheel.bzl",
"@org_tensorflow//third_party/xla/third_party/py:python_wheel.bzl",
"python_wheel_version_suffix_repository",
)

python_wheel_version_suffix_repository(name = "tf_wheel_version_suffix")
python_wheel_version_suffix_repository(
name = "tf_wheel_version_suffix",
)

load(
"@local_xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
)

Expand All @@ -170,7 +202,7 @@ load(
"CUDNN_REDISTRIBUTIONS",
)
load(
"@local_xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"cuda_redist_init_repositories",
"cudnn_redist_init_repository",
)
Expand All @@ -184,21 +216,21 @@ cudnn_redist_init_repository(
)

load(
"@local_xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"cuda_configure",
)

cuda_configure(name = "local_config_cuda")

load(
"@local_xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"nccl_redist_init_repository",
)

nccl_redist_init_repository()

load(
"@local_xla//third_party/nccl/hermetic:nccl_configure.bzl",
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_configure.bzl",
"nccl_configure",
)

Expand Down
2 changes: 1 addition & 1 deletion oss_scripts/configure.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ else
if [[ "$IS_NIGHTLY" == "nightly" ]]; then
pip install tf-nightly
else
pip install tensorflow==2.18.0
pip install tensorflow==2.20.0
fi
fi

Expand Down
60 changes: 54 additions & 6 deletions oss_scripts/pip_package/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@org_tensorflow//third_party/xla/third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")

# Tools for building the TF.Text pip package.
load("@python//:defs.bzl", "compile_pip_requirements")
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
load("//oss_scripts/pip_package:wheel.bzl", "tensorflow_text_wheel")

package(default_visibility = ["//visibility:private"])

Expand All @@ -27,14 +31,58 @@ py_binary(
],
)

sh_binary(
name = "build_pip_package",
srcs = ["build_pip_package.sh"],
data = [
string_flag(
name = "output_path",
build_setting_default = "dist",
)

py_binary(
name = "build_wheel_py",
srcs = ["build_wheel.py"],
main = "build_wheel.py",
deps = [
#":build_utils",
#"@bazel_tools//tools/python/runfiles",
#"@pypi//build",
#"@pypi//setuptools",
#"@pypi//wheel",
],
)

filegroup(
name = "wheel_sources",
srcs = [
"LICENSE",
"MANIFEST.in",
"setup.nightly.py",
"setup.py",
"//tensorflow_text",
":transitive_data_deps",
":transitive_py_deps",
],
)

transitive_py_deps(
name = "transitive_py_deps",
deps = ["//tensorflow_text"],
)

collect_data_files(
name = "transitive_data_deps",
deps = ["//tensorflow_text"],
)

tensorflow_text_wheel(
name = "tensorflow_text_wheel",
srcs = [":wheel_sources"],
)

#sh_binary(
# name = "build_pip_package",
# srcs = ["build_pip_package.sh"],
# data = [
# "LICENSE",
# "MANIFEST.in",
# "setup.nightly.py",
# "setup.py",
# "//tensorflow_text",
# ],
#)
129 changes: 129 additions & 0 deletions oss_scripts/pip_package/build_wheel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2025 TF.Text Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script that builds a tf text wheel, intended to be run via bazel."""

import argparse
import os
import pathlib
import shutil
import subprocess
import sys
import tempfile

# Command-line interface of the wheel builder.
# `fromfile_prefix_chars="@"` makes argparse treat an argument of the form
# `@path` as a file from which further arguments are read.
# NOTE: the arguments are parsed at import time (see `parse_args()` below), so
# this module is only usable as a script, not as an importable library.
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
# Directory that receives the built wheel; the flag is mandatory.
parser.add_argument(
    "--output_path",
    default=None,
    required=True,
    help="Path to which the output wheel should be written. Required.",
)
# May be given several times; each occurrence is appended to a list.
# `args.srcs` is None when the flag is never passed.
parser.add_argument(
    "--srcs", help="source files for the wheel", action="append"
)
# Forwarded verbatim to setup.py as `--plat-name`; defaults to an empty string.
parser.add_argument(
    "--platform",
    default="",
    required=False,
    help="Platform name to be passed to setup.py",
)
args = parser.parse_args()


def copy_file(
    src_file: str,
    dst_dir: str,
) -> None:
  """Mirror one file into ``dst_dir``, keeping its directory layout.

  The directory part of ``src_file`` is recreated underneath ``dst_dir``, the
  file is copied into it, and the path ``dst_dir/src_file`` is then given
  mode 0o644.

  Args:
    src_file: path of the file to copy.
    dst_dir: root directory under which the copy is placed.
  """
  relative_dir, _ = os.path.split(src_file)
  target_dir = os.path.join(dst_dir, relative_dir)
  os.makedirs(target_dir, exist_ok=True)
  shutil.copy(src_file, target_dir)

  # Normalise permissions on the copy regardless of the source file's mode.
  copied_path = os.path.join(dst_dir, src_file)
  os.chmod(copied_path, 0o644)


def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
  """Copy the selected source files into ``srcs_dir``.

  Every path in ``deps`` is printed. Paths that begin with "bazel-out" or
  "external" are skipped; all remaining paths are copied with ``copy_file``.

  Args:
    deps: paths of the candidate files.
    srcs_dir: directory that receives the copies.
  """
  # Presumably Bazel-generated outputs and external-repository files; they are
  # left out of the staged wheel sources.
  skipped_prefixes = ("bazel-out", "external")

  for path in deps:
    print(path)
    if path.startswith(skipped_prefixes):
      continue
    copy_file(path, srcs_dir)


def build_wheel(
    dir_path: str,
    cwd: str,
    platform: str,
) -> None:
  """Run ``setup.nightly.py bdist_wheel`` with the current interpreter.

  Args:
    dir_path: directory passed to setup.py as ``--dist-dir``.
    cwd: working directory for the subprocess; ``setup.nightly.py`` is
      resolved relative to it.
    platform: value passed to setup.py as ``--plat-name``.

  Raises:
    subprocess.CalledProcessError: if the setup.py invocation exits with a
      non-zero status.
  """
  command = [sys.executable, "setup.nightly.py", "bdist_wheel"]
  command.append(f"--dist-dir={dir_path}")
  command.append(f"--plat-name={platform}")

  subprocess.run(command, cwd=cwd, check=True)


# Stage the wheel sources in a scratch directory and build from there. The
# context manager removes the directory on exit, including when the build
# fails. `TemporaryDirectory` exposes its location as `.name` (there is no
# `.path` attribute), which is what the `as` target is bound to.
with tempfile.TemporaryDirectory(prefix="tensorflow_text") as sources_path:
  os.makedirs(args.output_path, exist_ok=True)
  # `args.srcs` is None when no --srcs flag was given; treat that as "no files".
  prepare_srcs(args.srcs or [], pathlib.Path(sources_path))
  build_wheel(
      # setup.py runs with `sources_path` as its working directory, so the
      # output directory has to be anchored to the current one.
      os.path.join(os.getcwd(), args.output_path),
      sources_path,
      args.platform,
  )
8 changes: 4 additions & 4 deletions oss_scripts/pip_package/requirements.in
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
setuptools==70.0.0
dm-tree==0.1.8 # Limit for macos support.
numpy
protobuf==4.25.3 # b/397977335 - Fix crash on python 3.9, 3.10.
tensorflow
#protobuf==4.25.3 # b/397977335 - Fix crash on python 3.9, 3.10.
tensorflow==2.20.0
tf-keras
tensorflow-datasets
tensorflow-metadata
#tensorflow-datasets
#tensorflow-metadata
40 changes: 40 additions & 0 deletions oss_scripts/pip_package/tensorflow_text_python_wheel.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Repository rule to generate a file with TF text wheel version.
def _tensorflow_text_python_wheel_repository_impl(repository_ctx):
    """Generates `wheel.bzl` exposing the wheel version as `WHEEL_VERSION`.

    Reads the file named by the `version_source` attribute, takes the text from
    the first occurrence of `version_key` up to the end of that line, replaces
    `version_key` with `WHEEL_VERSION`, and writes the result to `wheel.bzl`
    next to an empty `BUILD` file.

    Args:
      repository_ctx: the repository context supplied by Bazel.
    """
    version_source = repository_ctx.attr.version_source
    version_key = repository_ctx.attr.version_key
    version_file_content = repository_ctx.read(
        repository_ctx.path(version_source),
    )

    version_start_index = version_file_content.find(version_key)
    if version_start_index < 0:
        # Without this check the slice below is empty and an empty wheel.bzl
        # would be written silently.
        fail("'%s' was not found in %s" % (version_key, version_source))

    version_end_index = version_file_content.find("\n", version_start_index)
    if version_end_index < 0:
        # The version line is the last one and the file has no trailing newline.
        version_end_index = len(version_file_content)

    wheel_version = version_file_content[version_start_index:version_end_index].replace(
        version_key,
        "WHEEL_VERSION",
    )
    repository_ctx.file(
        "wheel.bzl",
        wheel_version,
    )
    repository_ctx.file("BUILD", "")

# Attributes of the repository rule:
#   version_key: the text searched for in `version_source`.
#   version_source: label of the single file that is read for the version line.
_WHEEL_VERSION_REPOSITORY_ATTRS = {
    "version_key": attr.string(mandatory = True),
    "version_source": attr.label(allow_single_file = True, mandatory = True),
}

tensorflow_text_python_wheel_repository = repository_rule(
    attrs = _WHEEL_VERSION_REPOSITORY_ATTRS,
    implementation = _tensorflow_text_python_wheel_repository_impl,
)
Loading