Skip to content

Commit 6228366

Browse files
authored
Repo sync (secretflow#366)
# Pull Request ## What problem does this PR solve? Issue Number: Fixed # ## Possible side effects? - Performance: - Backward compatibility:
1 parent 815c15e commit 6228366

File tree

254 files changed

+1783
-1496
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

254 files changed

+1783
-1496
lines changed

benchmark/binary_op_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414

1515
import argparse
1616
import json
17+
import time
1718

1819
import jax.numpy as jnp
1920
import numpy as np
20-
import time
2121

2222
import spu.utils.distributed as ppd
2323

benchmark/unary_op_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414

1515
import argparse
1616
import json
17+
import time
1718

1819
import jax.numpy as jnp
1920
import numpy as np
20-
import time
2121

2222
import spu.utils.distributed as ppd
2323

docs/reference/gen_benchmark_report.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
# limitations under the License.
1616

1717

18+
import argparse
1819
import json
19-
import pandas as pd
20-
import numpy as np
2120
import os
2221
from enum import Enum
23-
import argparse
22+
23+
import numpy as np
24+
import pandas as pd
2425

2526
g_time_list = [
2627
('ns', 1000),

docs/reference/gen_complexity_md.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616

1717

1818
import argparse
19-
from pytablewriter import MarkdownTableWriter
2019
import json
2120

21+
from pytablewriter import MarkdownTableWriter
22+
2223

2324
def main():
2425
parser = argparse.ArgumentParser(

docs/reference/gen_np_op_status_doc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616

1717

1818
import argparse
19-
from pytablewriter import MarkdownTableWriter
2019
import json
20+
2121
from mdutils.mdutils import MdUtils
22+
from pytablewriter import MarkdownTableWriter
2223

2324

2425
def main():

docs/tutorials/cpp_lr_example.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ In the first terminal.
2020

2121
.. code-block:: bash
2222
23-
bazel run //examples/cpp:simple_lr -- -rank 0 -dataset examples/cpp/data/perfect_logit_a.csv -has_label=true
23+
bazel run //examples/cpp:simple_lr -- -rank 0 -dataset examples/cpp/perfect_logit_a.csv -has_label=true
2424
2525
In the second terminal.
2626

2727
.. code-block:: bash
2828
29-
bazel run //examples/cpp:simple_lr -- -rank 1 -dataset examples/cpp/data/perfect_logit_b.csv
29+
bazel run //examples/cpp:simple_lr -- -rank 1 -dataset examples/cpp/perfect_logit_b.csv
3030

examples/cpp/BUILD.bazel

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@ spu_cc_binary(
2525
deps = [
2626
":utils",
2727
"//libspu/device:io",
28-
"//libspu/kernel/hal",
28+
"//libspu/kernel/hal:public_helper",
29+
"//libspu/kernel/hlo:basic_binary",
30+
"//libspu/kernel/hlo:basic_unary",
31+
"//libspu/kernel/hlo:casting",
32+
"//libspu/kernel/hlo:const",
33+
"//libspu/kernel/hlo:geometrical",
2934
"@com_google_absl//absl/strings",
3035
"@llvm-project//llvm:Support",
3136
"@yacl//yacl/link:factory",

examples/cpp/simple_lr.cc

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,45 +24,51 @@
2424

2525
#include "examples/cpp/utils.h"
2626
#include "spdlog/spdlog.h"
27+
#include "xtensor/xarray.hpp"
2728
#include "xtensor/xcsv.hpp"
29+
#include "xtensor/xview.hpp"
2830

2931
#include "libspu/device/io.h"
30-
#include "libspu/kernel/hal/hal.h"
31-
#include "libspu/kernel/hal/type_cast.h"
32+
#include "libspu/kernel/hal/public_helper.h"
33+
#include "libspu/kernel/hlo/basic_binary.h"
34+
#include "libspu/kernel/hlo/basic_unary.h"
35+
#include "libspu/kernel/hlo/casting.h"
36+
#include "libspu/kernel/hlo/const.h"
37+
#include "libspu/kernel/hlo/geometrical.h"
3238
#include "libspu/mpc/factory.h"
3339

3440
using namespace spu::kernel;
3541

3642
spu::Value train_step(spu::SPUContext* ctx, const spu::Value& x,
3743
const spu::Value& y, const spu::Value& w) {
3844
// Padding x
39-
auto padding = hal::constant(ctx, 1.0F, spu::DT_F32, {x.shape()[0], 1});
40-
auto padded_x = hal::concatenate(ctx, {x, hal::seal(ctx, padding)}, 1);
41-
auto pred = hal::logistic(ctx, hal::matmul(ctx, padded_x, w));
45+
auto padding = hlo::Constant(ctx, 1.0F, {x.shape()[0], 1});
46+
auto padded_x = hlo::Concatenate(
47+
ctx, {x, hlo::Cast(ctx, padding, spu::VIS_SECRET, padding.dtype())}, 1);
48+
auto pred = hlo::Logistic(ctx, hlo::Dot(ctx, padded_x, w));
4249

4350
SPDLOG_DEBUG("[SSLR] Err = Pred - Y");
44-
auto err = hal::sub(ctx, pred, y);
51+
auto err = hlo::Sub(ctx, pred, y);
4552

4653
SPDLOG_DEBUG("[SSLR] Grad = X.t * Err");
47-
auto grad = hal::matmul(ctx, hal::transpose(ctx, padded_x), err);
54+
auto grad = hlo::Dot(ctx, hlo::Transpose(ctx, padded_x, {}), err);
4855

4956
SPDLOG_DEBUG("[SSLR] Step = LR / B * Grad");
50-
auto lr = hal::constant(ctx, 0.0001F, spu::DT_F32);
51-
auto msize =
52-
hal::constant(ctx, static_cast<float>(y.shape()[0]), spu::DT_F32);
53-
auto p1 = hal::mul(ctx, lr, hal::reciprocal(ctx, msize));
54-
auto step = hal::mul(ctx, hal::broadcast_to(ctx, p1, grad.shape()), grad);
57+
auto lr = hlo::Constant(ctx, 0.0001F, {});
58+
auto msize = hlo::Constant(ctx, static_cast<float>(y.shape()[0]), {});
59+
auto p1 = hlo::Mul(ctx, lr, hlo::Reciprocal(ctx, msize));
60+
auto step = hlo::Mul(ctx, hlo::Broadcast(ctx, p1, grad.shape(), {}), grad);
5561

5662
SPDLOG_DEBUG("[SSLR] W = W - Step");
57-
auto new_w = hal::sub(ctx, w, step);
63+
auto new_w = hlo::Sub(ctx, w, step);
5864

5965
return new_w;
6066
}
6167

6268
spu::Value train(spu::SPUContext* ctx, const spu::Value& x, const spu::Value& y,
6369
size_t num_epoch, size_t bsize) {
6470
const size_t num_iter = x.shape()[0] / bsize;
65-
auto w = hal::constant(ctx, 0.0F, spu::DT_F32, {x.shape()[1] + 1, 1});
71+
auto w = hlo::Constant(ctx, 0.0F, {x.shape()[1] + 1, 1});
6672

6773
// Run train loop
6874
for (size_t epoch = 0; epoch < num_epoch; ++epoch) {
@@ -73,10 +79,10 @@ spu::Value train(spu::SPUContext* ctx, const spu::Value& x, const spu::Value& y,
7379
const int64_t rows_end = rows_beg + bsize;
7480

7581
const auto x_slice =
76-
hal::slice(ctx, x, {rows_beg, 0}, {rows_end, x.shape()[1]}, {});
82+
hlo::Slice(ctx, x, {rows_beg, 0}, {rows_end, x.shape()[1]}, {});
7783

7884
const auto y_slice =
79-
hal::slice(ctx, y, {rows_beg, 0}, {rows_end, y.shape()[1]}, {});
85+
hlo::Slice(ctx, y, {rows_beg, 0}, {rows_end, y.shape()[1]}, {});
8086

8187
w = train_step(ctx, x_slice, y_slice, w);
8288
}
@@ -87,9 +93,10 @@ spu::Value train(spu::SPUContext* ctx, const spu::Value& x, const spu::Value& y,
8793

8894
spu::Value inference(spu::SPUContext* ctx, const spu::Value& x,
8995
const spu::Value& weight) {
90-
auto padding = hal::constant(ctx, 1.0F, spu::DT_F32, {x.shape()[0], 1});
91-
auto padded_x = hal::concatenate(ctx, {x, hal::seal(ctx, padding)}, 1);
92-
return hal::matmul(ctx, padded_x, weight);
96+
auto padding = hlo::Constant(ctx, 1.0F, {x.shape()[0], 1});
97+
auto padded_x = hlo::Concatenate(
98+
ctx, {x, hlo::Cast(ctx, padding, spu::VIS_SECRET, padding.dtype())}, 1);
99+
return hlo::Dot(ctx, padded_x, weight);
93100
}
94101

95102
float SSE(const xt::xarray<float>& y_true, const xt::xarray<float>& y_pred) {
@@ -143,7 +150,7 @@ std::pair<spu::Value, spu::Value> infeed(spu::SPUContext* sctx,
143150
auto x = cio.deviceGetVar("x-0");
144151
// Concatenate all slices
145152
for (size_t idx = 1; idx < cio.getWorldSize(); ++idx) {
146-
x = hal::concatenate(sctx, {x, cio.deviceGetVar(fmt::format("x-{}", idx))},
153+
x = hlo::Concatenate(sctx, {x, cio.deviceGetVar(fmt::format("x-{}", idx))},
147154
1);
148155
}
149156
auto y = cio.deviceGetVar("label");
@@ -175,10 +182,11 @@ int main(int argc, char** argv) {
175182

176183
const auto scores = inference(sctx.get(), x, w);
177184

178-
xt::xarray<float> revealed_labels =
179-
hal::dump_public_as<float>(sctx.get(), hal::reveal(sctx.get(), y));
180-
xt::xarray<float> revealed_scores =
181-
hal::dump_public_as<float>(sctx.get(), hal::reveal(sctx.get(), scores));
185+
xt::xarray<float> revealed_labels = hal::dump_public_as<float>(
186+
sctx.get(), hlo::Cast(sctx.get(), y, spu::VIS_PUBLIC, y.dtype()));
187+
xt::xarray<float> revealed_scores = hal::dump_public_as<float>(
188+
sctx.get(),
189+
hlo::Cast(sctx.get(), scores, spu::VIS_PUBLIC, scores.dtype()));
182190

183191
auto mse = MSE(revealed_labels, revealed_scores);
184192
std::cout << "MSE = " << mse << "\n";

examples/cpp/simple_pphlo.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
// clang-format on
2020

2121
#include "examples/cpp/utils.h"
22-
#include "spdlog/spdlog.h"
2322

2423
#include "libspu/device/api.h"
2524
#include "libspu/device/io.h"

examples/cpp/utils.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
#include "examples/cpp/utils.h"
1616

17-
#include "absl/strings/match.h"
1817
#include "absl/strings/str_split.h"
18+
#include "yacl/link/factory.h"
1919

2020
#include "libspu/core/config.h"
2121

@@ -41,7 +41,7 @@ std::shared_ptr<yacl::link::Context> MakeLink(const std::string& parties,
4141
std::vector<std::string> hosts = absl::StrSplit(parties, ',');
4242
for (size_t rank = 0; rank < hosts.size(); rank++) {
4343
const auto id = fmt::format("party{}", rank);
44-
lctx_desc.parties.push_back({id, hosts[rank]});
44+
lctx_desc.parties.emplace_back(id, hosts[rank]);
4545
}
4646
auto lctx = yacl::link::FactoryBrpc().CreateContext(lctx_desc, rank);
4747
lctx->ConnectToMesh();

examples/python/ir_dump/ir_dump.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
import jax.numpy as jnp
2929
import numpy as np
3030

31-
import spu.utils.distributed as ppd
3231
import spu.spu_pb2 as spu_pb2
32+
import spu.utils.distributed as ppd
3333

3434
logging.basicConfig(level=logging.INFO)
3535

examples/python/ml/flax_llama7b/flax_llama7b.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,20 @@
1919

2020
import argparse
2121
import json
22+
from contextlib import contextmanager
23+
from typing import Any, Optional, Tuple, Union
24+
25+
import flax.linen as nn
2226
import jax
23-
import jax.numpy as jnp
2427
import jax.nn as jnn
25-
import flax.linen as nn
28+
import jax.numpy as jnp
29+
from EasyLM.models.llama.llama_model import FlaxLLaMAForCausalLM, LLaMAConfig
2630
from flax.linen.linear import Array
27-
from typing import Any, Optional, Tuple, Union
2831
from transformers import LlamaTokenizer
29-
from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
30-
import spu.utils.distributed as ppd
31-
from contextlib import contextmanager
32+
3233
import spu.intrinsic as intrinsic
3334
import spu.spu_pb2 as spu_pb2
35+
import spu.utils.distributed as ppd
3436

3537
parser = argparse.ArgumentParser(description='distributed driver.')
3638
parser.add_argument("-c", "--config", default="examples/python/ml/flax_llama/3pc.json")

examples/python/ml/flax_mlp/flax_mlp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,10 @@ def run_on_cpu():
106106

107107
ppd.init(conf["nodes"], conf["devices"])
108108

109-
import cloudpickle as pickle
110109
import tempfile
111110

111+
import cloudpickle as pickle
112+
112113

113114
def compute_score(param, type):
114115
x_test, y_test = dsutil.breast_cancer(slice(None, None, None), False)

examples/python/ml/flax_resnet/flax_resnet.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,17 @@
1717
# See issue #620.
1818
# pytype: disable=wrong-arg-count
1919

20-
from typing import Any
2120
import argparse
2221
import time
22+
from typing import Any
2323

24+
import jax
25+
import jax.numpy as jnp
26+
import optax
2427
import tensorflow as tf
2528
import tensorflow_datasets as tfds
26-
27-
import optax
2829
from flax.training import train_state
29-
import jax.numpy as jnp
30-
import jax
3130
from jax import random
32-
3331
from models import ResNet18
3432

3533
NUM_CLASSES = 10

examples/python/ml/flax_resnet/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from functools import partial
2121
from typing import Any, Callable, Sequence, Tuple
2222

23-
from flax import linen as nn
2423
import jax.numpy as jnp
24+
from flax import linen as nn
2525

2626
ModuleDef = Any
2727

examples/python/ml/flax_vae/flax_vae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
import optax
2222
import tensorflow as tf
2323
import tensorflow_datasets as tfds
24+
from flax import linen as nn
25+
from flax.training import train_state
2426
from jax import random
2527

2628
import examples.python.ml.flax_vae.utils as vae_utils
27-
from flax import linen as nn
28-
from flax.training import train_state
2929

3030
# Replace absl.flags used by original authors with argparse for unittest
3131
parser = argparse.ArgumentParser(description='distributed driver.')

examples/python/ml/jraph_gnn/jraph_gnn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424

2525
import logging
2626

27-
from absl import app
2827
import haiku as hk
2928
import jax
3029
import jax.numpy as jnp
3130
import jraph
3231
import optax
32+
from absl import app
3333

3434

3535
def get_zacharys_karate_club() -> jraph.GraphsTuple:
@@ -252,9 +252,10 @@ def predict(params):
252252

253253

254254
import argparse
255-
import spu.utils.distributed as ppd
256255
import json
257256

257+
import spu.utils.distributed as ppd
258+
258259
parser = argparse.ArgumentParser(description="distributed driver.")
259260
parser.add_argument("-c", "--config", default="examples/python/conf/3pc.json")
260261
args = parser.parse_args()

examples/python/ml/ml_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,20 @@
1313
# limitations under the License.
1414

1515

16+
import inspect
1617
import json
1718
import logging
19+
import os
1820
import sys
1921
import unittest
2022
from time import perf_counter
21-
import os
2223

2324
import multiprocess
2425
import numpy.testing as npt
2526
import pandas as pd
26-
import inspect
2727

2828
import spu.utils.distributed as ppd
2929

30-
3130
with open("examples/python/conf/3pc.json", 'r') as file:
3231
conf = json.load(file)
3332

0 commit comments

Comments
 (0)