[Mahjong] Reduce compilation/run time #1070
sotetsuk changed the title from "[Mahjong] Reduce compilation time" to "[Mahjong] Reduce compilation/run time" on Nov 8, 2023.
benchmark.py

```python
import sys
import time
import timeit

import jax

from pgx.mahjong._env import (
    Mahjong,
    _discard,
    _selfkan,
    _riichi,
    _tsumo,
    _ron,
    _pon,
    _minkan,
    _pass,
)

# Functions called as func(state, action)
functions1 = {"_discard": _discard, "_selfkan": _selfkan}
# Functions called as func(state)
functions2 = {
    "_riichi": _riichi,
    "_tsumo": _tsumo,
    "_ron": _ron,
    "_pon": _pon,
    "_minkan": _minkan,
    "_pass": _pass,
}

env = Mahjong()
N = 10  # number of timed calls after compilation
func_name = sys.argv[1]

key = jax.random.PRNGKey(352)
state = env.init(key=key)

if func_name in functions1:
    func = functions1[func_name]
    args = (state, 0)  # action 0
elif func_name in functions2:
    func = functions2[func_name]
    args = (state,)
else:
    sys.exit(f"unknown function: {func_name}")

# Program size: number of lines in the printed jaxpr
exp = jax.make_jaxpr(func)(*args)
n_line = len(str(exp).split("\n"))

# "Compile time": the first call traces, compiles and runs once.
# block_until_ready waits for JAX's asynchronous dispatch to finish.
jit_func = jax.jit(func)
time_sta = time.perf_counter()
jax.block_until_ready(jit_func(*args))
time_end = time.perf_counter()
delta = (time_end - time_sta) * 1000  # ms

# "Running time": average over N calls of the already-compiled function
run_delta = timeit.timeit(lambda: jax.block_until_ready(jit_func(*args)), number=N)

print(
    f"| `{func.__name__}` | {n_line} | {delta:.1f}ms | {run_delta/N*1000000:.1f}μs |"
)
```
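The "compile time" column above measures the first call, which mixes tracing, XLA compilation and one execution. A minimal sketch of splitting those phases with JAX's ahead-of-time API (`jax.jit(f).lower(...)` followed by `.compile()`); `toy_step` is a hypothetical stand-in for a pgx function such as `_discard(state, action)`, not pgx code.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(x, action):
    # Hypothetical stand-in for a pgx Mahjong function such as _discard(state, action)
    return jnp.where(action == 0, jnp.cumsum(x), x * 2)


def profile_phases(func, *args):
    """Time lowering, XLA compilation and one execution separately (in ms)."""
    jitted = jax.jit(func)

    t0 = time.perf_counter()
    lowered = jitted.lower(*args)  # trace to a jaxpr and lower for XLA
    t1 = time.perf_counter()
    compiled = lowered.compile()  # XLA compilation only
    t2 = time.perf_counter()
    out = compiled(*args)  # run the compiled executable once
    jax.block_until_ready(out)  # wait for asynchronous dispatch
    t3 = time.perf_counter()

    times = {
        "lower_ms": (t1 - t0) * 1000,
        "compile_ms": (t2 - t1) * 1000,
        "first_run_ms": (t3 - t2) * 1000,
    }
    return times, out


times, out = profile_phases(toy_step, jnp.arange(4.0), jnp.int32(0))
print(times)
```

Reporting `lower_ms` and `compile_ms` separately shows whether a function is expensive because its traced program is large (lowering) or because XLA struggles to optimize it (compilation).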
benchmark.sh

```shell
echo "| function name | # expr lines | compile time | running time |"
echo "| :--- | ---: | ---: | ---: |"
for funcname in _discard _selfkan _riichi _tsumo _ron _pon _minkan _pass
do
    python3 benchmark.py "$funcname"
done
```
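The "running time" column is a single average over N = 10 calls, which can be noisy. A minimal sketch of a steadier variant using `timeit.repeat` and reporting the fastest round; `toy_fn` is a hypothetical stand-in for a jitted pgx function, and the repeat/number values are arbitrary choices.

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def toy_fn(x):
    # Hypothetical stand-in for a jitted pgx function such as _pass(state)
    return jnp.tanh(x) @ x.T


x = jnp.ones((64, 64))
jax.block_until_ready(toy_fn(x))  # warm-up call: compile once, outside the timing

# 5 rounds of 10 calls each; block_until_ready makes each call include device execution
rounds = timeit.repeat(lambda: jax.block_until_ready(toy_fn(x)), repeat=5, number=10)
per_call_us = min(rounds) / 10 * 1e6  # fastest round, per call, in microseconds
print(f"{per_call_us:.1f}μs")
```

Taking the minimum over rounds filters out interference from other processes, which only ever makes a round slower.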
Main functions called inside `step` (excerpt):

- `_discard`
- `_selfkan`
- `_riichi`
- `_tsumo`
- `_ron`
- `_pon`
- `_minkan`
- `_pass`
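Benchmarking these functions one by one is useful if `step` dispatches to them through `jax.lax.switch`; that dispatch mechanism is an assumption here, not something confirmed by this issue. `lax.switch` traces and compiles every branch into one program, so each handler adds to the compile time of `step` even when only one branch runs. A toy sketch with hypothetical handlers:

```python
import jax
import jax.numpy as jnp


# Hypothetical handlers standing in for _discard, _riichi, _pass, ...
def handler_a(x):
    return x + 1


def handler_b(x):
    return jnp.sort(x)


def handler_c(x):
    return jnp.cumsum(x)


BRANCHES = [handler_a, handler_b, handler_c]


def toy_step(branch_idx, x):
    # lax.switch traces and compiles every branch, then selects one at run time
    return jax.lax.switch(branch_idx, BRANCHES, x)


x = jnp.array([3.0, 1.0, 2.0])
jaxpr = jax.make_jaxpr(toy_step)(0, x)  # contains all three branches
n_lines = len(str(jaxpr).split("\n"))

out = jax.jit(toy_step)(1, x)
print(out)  # [1. 2. 3.]
print(n_lines)
```

The jaxpr traced with index 0 still contains the `sort` and `cumsum` branches, which is why shrinking any single handler can reduce the size of the whole `step` program.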