Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure transpose function works for any bigarray kind. #41

Merged
merged 1 commit into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions lib/codecs/array_to_array.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
open Codecs_intf

module Ndarray = Owl.Dense.Ndarray.Generic

(* https://zarr-specs.readthedocs.io/en/latest/v3/codecs/transpose/v1.0.html *)
module TransposeCodec = struct
Expand Down Expand Up @@ -29,7 +28,6 @@ module TransposeCodec = struct
the decoded representation dimensionality." in
Result.error @@ `Transpose_order (t, msg)


let parse_order o =
if Array.length o = 0 then
let msg = "transpose order cannot be empty." in
Expand Down Expand Up @@ -66,14 +64,24 @@ module TransposeCodec = struct
else
Ok ()

let encode o x =
try Ok (Ndarray.transpose ~axis:o x) with
| Failure s -> Error (`Transpose_order (o, s))
(* NOTE: See https://github.com/owlbarn/owl/issues/671#issuecomment-2241761001 *)
let transpose ?axis x =
let module A = Owl.Dense.Ndarray.Any in
let module N = Owl.Dense.Ndarray.Generic in
try
let y = A.transpose ?axis @@ A.init_nd (N.shape x) @@ N.get x in
Result.ok @@ N.init_nd (N.kind x) (A.shape y) @@ A.get y
with
| Assert_failure _ ->
Result.error @@
`Transpose_order (Option.get axis, "Invalid transpose order.")

let encode o x = transpose ~axis:o x

let decode o x =
let inv_order = Array.(make (length o) 0) in
Array.iteri (fun i x -> inv_order.(x) <- i) o;
Ok (Ndarray.transpose ~axis:inv_order x)
transpose ~axis:inv_order x

let to_yojson order =
let o =
Expand Down
22 changes: 4 additions & 18 deletions test/test_codecs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ let bytes_encode_decode
let tests = [
"test codec chain" >:: (fun _ ->
let decoded_repr
: (float, Bigarray.float32_elt) array_repr =
: (int, Bigarray.int16_signed_elt) array_repr =
{shape = [|10; 15; 10|]
;kind = Bigarray.Float32
;fill_value = (-10.)}
;kind = Bigarray.Int16_signed
;fill_value = 10}
in
let shard_cfg =
{chunk_shape = [|2; 5; 5|]
Expand Down Expand Up @@ -323,7 +323,7 @@ let tests = [
let cfg =
{chunk_shape = [|3; 5; 5|]
;index_location = Start
;index_codecs = [`Bytes LE; `Crc32c]
;index_codecs = [`Transpose [|0; 3; 1; 2|]; `Bytes LE; `Crc32c]
;codecs = [`Bytes BE]}
in
let chain = [`ShardingIndexed cfg] in
Expand Down Expand Up @@ -362,20 +362,6 @@ let tests = [
assert_failure
"Successfully encoded array should decode without fail");

(* test if including a transpose codec for index_codec chain results in
a failure. *)
let chain' =
[`ShardingIndexed {cfg with
chunk_shape = [|5; 3; 5|]
;index_codecs = `Transpose [|0; 3; 1; 2|] :: cfg.index_codecs}]
in
let cc = Chain.create decoded_repr chain' |> Result.get_ok in
assert_bool
"shard index chain can't be encoded since Owl does not support transposing
Int64 types. See:
https://github.com/owlbarn/owl/issues/671#issuecomment-2211303040" @@
Result.is_error @@ Chain.encode cc arr;

(* test correctness of decoding nested sharding codecs.*)
let str =
{|[
Expand Down
Loading