Skip to content

Commit

Permalink
Ensure transpose function works for any bigarray kind.
Browse files Browse the repository at this point in the history
  • Loading branch information
zoj613 committed Jul 22, 2024
1 parent 20af412 commit c586ba2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 24 deletions.
19 changes: 13 additions & 6 deletions lib/codecs/array_to_array.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
open Codecs_intf

module Ndarray = Owl.Dense.Ndarray.Generic

(* https://zarr-specs.readthedocs.io/en/latest/v3/codecs/transpose/v1.0.html *)
module TransposeCodec = struct
Expand Down Expand Up @@ -29,7 +28,6 @@ module TransposeCodec = struct
the decoded representation dimensionality." in
Result.error @@ `Transpose_order (t, msg)


let parse_order o =
if Array.length o = 0 then
let msg = "transpose order cannot be empty." in
Expand Down Expand Up @@ -66,14 +64,23 @@ module TransposeCodec = struct
else
Ok ()

let encode o x =
try Ok (Ndarray.transpose ~axis:o x) with
| Failure s -> Error (`Transpose_order (o, s))
let transpose ?axis x =
let module A = Owl.Dense.Ndarray.Any in
let module N = Owl.Dense.Ndarray.Generic in
try
let y = A.transpose ?axis @@ A.init_nd (N.shape x) @@ N.get x in
Result.ok @@ N.init_nd (N.kind x) (A.shape y) @@ A.get y
with
| Assert_failure _ ->
Result.error @@
`Transpose_order (Option.get axis, "Invalid transpose order.")

let encode o x = transpose ~axis:o x

let decode o x =
let inv_order = Array.(make (length o) 0) in
Array.iteri (fun i x -> inv_order.(x) <- i) o;
Ok (Ndarray.transpose ~axis:inv_order x)
transpose ~axis:inv_order x

let to_yojson order =
let o =
Expand Down
22 changes: 4 additions & 18 deletions test/test_codecs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ let bytes_encode_decode
let tests = [
"test codec chain" >:: (fun _ ->
let decoded_repr
: (float, Bigarray.float32_elt) array_repr =
: (int, Bigarray.int16_signed_elt) array_repr =
{shape = [|10; 15; 10|]
;kind = Bigarray.Float32
;fill_value = (-10.)}
;kind = Bigarray.Int16_signed
;fill_value = 10}
in
let shard_cfg =
{chunk_shape = [|2; 5; 5|]
Expand Down Expand Up @@ -323,7 +323,7 @@ let tests = [
let cfg =
{chunk_shape = [|3; 5; 5|]
;index_location = Start
;index_codecs = [`Bytes LE; `Crc32c]
;index_codecs = [`Transpose [|0; 3; 1; 2|]; `Bytes LE; `Crc32c]
;codecs = [`Bytes BE]}
in
let chain = [`ShardingIndexed cfg] in
Expand Down Expand Up @@ -362,20 +362,6 @@ let tests = [
assert_failure
"Successfully encoded array should decode without fail");

(* test if including a transpose codec for index_codec chain results in
a failure. *)
let chain' =
[`ShardingIndexed {cfg with
chunk_shape = [|5; 3; 5|]
;index_codecs = `Transpose [|0; 3; 1; 2|] :: cfg.index_codecs}]
in
let cc = Chain.create decoded_repr chain' |> Result.get_ok in
assert_bool
"shard index chain can't be encoded since Owl does not support transposing
Int64 types. See:
https://github.com/owlbarn/owl/issues/671#issuecomment-2211303040" @@
Result.is_error @@ Chain.encode cc arr;

(* test correctness of decoding nested sharding codecs.*)
let str =
{|[
Expand Down

0 comments on commit c586ba2

Please sign in to comment.