Skip to content
This repository was archived by the owner on Mar 4, 2025. It is now read-only.

Commit a5bc57e

Browse files
committed
npy: first stab at an n-dim array with support for ragged arrays
Fixes #20. Signed-off-by: Sebastien Binet <[email protected]>
1 parent 862adbe commit a5bc57e

18 files changed

+3340
-11
lines changed

dump.go

+3-8
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"fmt"
1010
"io"
1111
"os"
12-
"reflect"
1312
"strings"
1413

1514
"github.com/sbinet/npyio/npy"
@@ -128,15 +127,11 @@ func display(o io.Writer, f io.Reader, fname string) error {
128127

129128
fmt.Fprintf(o, "npy-header: %v\n", r.Header)
130129

131-
rt := npy.TypeFrom(r.Header.Descr.Type)
132-
if rt == nil {
133-
return fmt.Errorf("npyio: no reflect type for %q", r.Header.Descr.Type)
134-
}
135-
rv := reflect.New(reflect.SliceOf(rt))
136-
err = r.Read(rv.Interface())
130+
var arr npy.Array
131+
err = r.Read(&arr)
137132
if err != nil && err != io.EOF {
138133
return fmt.Errorf("npyio: read error: %w", err)
139134
}
140-
fmt.Fprintf(o, "data = %v\n", rv.Elem().Interface())
135+
fmt.Fprintf(o, "data = %v\n", arr.Data())
141136
return nil
142137
}

dump_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ func TestDump(t *testing.T) {
3636
name: "testdata/data_float64_forder.npz",
3737
want: "testdata/data_float64_forder.npz.txt",
3838
},
39+
{
40+
name: "testdata/ragged-array.npy",
41+
want: "testdata/ragged-array.npy.txt",
42+
},
43+
{
44+
name: "testdata/ragged-array-mixed.npy",
45+
want: "testdata/ragged-array-mixed.npy.txt",
46+
},
3947
} {
4048
t.Run(tc.name, func(t *testing.T) {
4149
f, err := os.Open(tc.name)

go.mod

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ go 1.20
44

55
require (
66
github.com/campoy/embedmd v1.0.0
7+
github.com/nlpodyssey/gopickle v0.3.0
8+
golang.org/x/text v0.14.0
79
gonum.org/v1/gonum v0.14.0
810
)
911

go.sum

+4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY=
22
github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8=
3+
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
4+
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
35
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
46
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
57
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
8+
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
9+
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
610
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
711
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=

npy/array.go

+221
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
// Copyright 2023 The npyio Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package npy
6+
7+
import (
8+
"fmt"
9+
"strings"
10+
11+
py "github.com/nlpodyssey/gopickle/types"
12+
)
13+
14+
// Array is a multidimensional, homogeneous array of fixed-size items.
//
// An Array is populated either through PyNew or through PySetState, and is
// inspected through its accessor methods (Descr, Shape, Strides, Fortran
// and Data).
15+
type Array struct {
16+
// descr is the array's data type descriptor.
descr ArrayDescr
17+
// shape holds the array's shape; nil when no shape has been set.
shape []int
18+
// strides holds the array's strides in bytes.
strides []int
19+
// fortran reports whether the data is stored in FORTRAN-order
// (column-major) instead of C-order (row-major).
fortran bool
20+
21+
// data is the array's underlying data. Depending on how the array was
// populated it holds the raw []byte given to newArray, a *py.List, or
// the value returned by ArrayDescr.unmarshal.
data any
22+
}
23+
24+
// Compile-time checks that *Array satisfies the gopickle interfaces
// implemented below by PyNew and PySetState.
var (
25+
_ py.PyNewable = (*Array)(nil)
26+
_ py.PyStateSettable = (*Array)(nil)
27+
)
28+
29+
func (*Array) PyNew(args ...any) (any, error) {
30+
var (
31+
subtype = args[0]
32+
descr = args[1].(*ArrayDescr)
33+
shape = args[2].([]int)
34+
strides = args[3].([]int)
35+
data = args[4].([]byte)
36+
flags = args[5].(int)
37+
)
38+
39+
return newArray(subtype, *descr, shape, strides, data, flags)
40+
}
41+
42+
func newArray(subtype any, descr ArrayDescr, shape, strides []int, data []byte, flags int) (*Array, error) {
43+
switch subtype := subtype.(type) {
44+
case *Array:
45+
// ok.
46+
default:
47+
return nil, fmt.Errorf("subtyping ndarray with %T is not (yet?) supported", subtype)
48+
}
49+
50+
arr := &Array{
51+
descr: descr,
52+
shape: shape,
53+
strides: strides,
54+
data: data,
55+
}
56+
return arr, nil
57+
}
58+
59+
func (arr *Array) PySetState(arg any) error {
60+
tuple, ok := arg.(*py.Tuple)
61+
if !ok {
62+
return fmt.Errorf("invalid argument type %T", arg)
63+
}
64+
65+
var (
66+
vers = 0
67+
shape py.Tuple
68+
raw any
69+
)
70+
switch tuple.Len() {
71+
case 5:
72+
err := parseTuple(tuple, &vers, &shape, &arr.descr, &arr.fortran, nil)
73+
if err != nil {
74+
return fmt.Errorf("could not parse ndarray.__setstate__ tuple: %w", err)
75+
}
76+
raw = tuple.Get(4)
77+
case 4:
78+
err := parseTuple(tuple, &shape, &arr.descr, &arr.fortran, nil)
79+
if err != nil {
80+
return fmt.Errorf("could not parse ndarray.__setstate__ tuple: %w", err)
81+
}
82+
raw = tuple.Get(3)
83+
default:
84+
return fmt.Errorf("invalid length (%d) for ndarray.__setstate__ tuple", tuple.Len())
85+
}
86+
87+
arr.shape = nil
88+
for i := range shape {
89+
v, ok := shape.Get(i).(int)
90+
if !ok {
91+
return fmt.Errorf("invalid shape[%d]: got=%T, want=int", i, shape.Get(i))
92+
}
93+
arr.shape = append(arr.shape, v)
94+
}
95+
96+
err := arr.setupStrides()
97+
if err != nil {
98+
return fmt.Errorf("ndarray.__setstate__ could not infer strides: %w", err)
99+
}
100+
101+
switch raw := raw.(type) {
102+
case *py.List:
103+
arr.data = raw
104+
105+
case []byte:
106+
data, err := arr.descr.unmarshal(raw, arr.shape)
107+
if err != nil {
108+
return fmt.Errorf("ndarray.__setstate__ could not unmarshal raw data: %w", err)
109+
}
110+
arr.data = data
111+
}
112+
113+
return nil
114+
}
115+
116+
func (arr *Array) setupStrides() error {
117+
// TODO(sbinet): complete implementation.
118+
// see: _array_fill_strides in numpy/_core/multiarray/ctors.c
119+
120+
if arr.shape == nil {
121+
arr.strides = nil
122+
return nil
123+
}
124+
125+
strides := make([]int, len(arr.shape))
126+
// FIXME(sbinet): handle non-contiguous arrays
127+
// FIXME(sbinet): handle FORTRAN arrays
128+
129+
var (
130+
// notCFContig bool
131+
noDim bool // a dimension != 1 was found
132+
)
133+
134+
// check if array is both FORTRAN- and C-contiguous
135+
for _, dim := range arr.shape {
136+
if dim != 1 {
137+
if noDim {
138+
// notCFContig = true
139+
break
140+
}
141+
noDim = true
142+
}
143+
}
144+
145+
itemsize := arr.descr.itemsize()
146+
switch {
147+
case arr.fortran:
148+
for i, dim := range arr.shape {
149+
strides[i] = itemsize
150+
switch {
151+
case dim != 0:
152+
itemsize *= dim
153+
default:
154+
// notCFContig = false
155+
}
156+
}
157+
158+
default:
159+
for i := len(arr.shape) - 1; i >= 0; i-- {
160+
dim := arr.shape[i]
161+
strides[i] = itemsize
162+
switch {
163+
case dim != 0:
164+
itemsize *= dim
165+
default:
166+
// notCFContig = false
167+
}
168+
}
169+
}
170+
171+
arr.strides = strides
172+
return nil
173+
}
174+
175+
// Descr returns the array's data type descriptor.
//
// The descriptor is returned by value.
176+
func (arr Array) Descr() ArrayDescr {
177+
return arr.descr
178+
}
179+
180+
// Shape returns the array's shape.
//
// The returned slice is the one held by the array, not a copy: modifying
// its elements modifies the array's shape.
181+
func (arr Array) Shape() []int {
182+
return arr.shape
183+
}
184+
185+
// Strides returns the array's strides in bytes.
//
// The returned slice is the one held by the array, not a copy: modifying
// its elements modifies the array's strides.
186+
func (arr Array) Strides() []int {
187+
return arr.strides
188+
}
189+
190+
// Fortran returns whether the array's data is stored in FORTRAN-order
190+
// (i.e. column-major) instead of C-order (i.e. row-major).
191+
func (arr Array) Fortran() bool {
192+
return arr.fortran
193+
}
195+
196+
// Data returns the array's underlying data.
//
// The value is returned as stored, without copy or conversion: its dynamic
// type depends on how the array was populated (see PySetState and newArray).
197+
func (arr Array) Data() any {
198+
return arr.data
199+
}
200+
201+
func (arr Array) String() string {
202+
o := new(strings.Builder)
203+
fmt.Fprintf(o, "Array{descr: %v, ", arr.descr)
204+
switch arr.shape {
205+
case nil:
206+
fmt.Fprintf(o, "shape: nil, ")
207+
default:
208+
fmt.Fprintf(o, "shape: %v, ", arr.shape)
209+
}
210+
switch arr.strides {
211+
case nil:
212+
fmt.Fprintf(o, "strides: nil, ")
213+
default:
214+
fmt.Fprintf(o, "strides: %v, ", arr.strides)
215+
}
216+
fmt.Fprintf(o, "fortran: %v, data: %+v}",
217+
arr.fortran,
218+
arr.data,
219+
)
220+
return o.String()
221+
}

npy/array_test.go

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright 2023 The npyio Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package npy
6+
7+
import (
8+
"fmt"
9+
"os"
10+
"reflect"
11+
"testing"
12+
)
13+
14+
func TestArrayStringer(t *testing.T) {
15+
f, err := os.Open("../testdata/data_float64_2x3x4_corder.npy")
16+
if err != nil {
17+
t.Fatalf("could not open testdata: %+v", err)
18+
}
19+
defer f.Close()
20+
21+
var arr Array
22+
err = Read(f, &arr)
23+
if err != nil {
24+
t.Fatalf("could not read data: %+v", err)
25+
}
26+
27+
var (
28+
want = `Array{descr: ArrayDescr{kind: 'f', order: '<', flags: 0, esize: 8, align: 8, subarr: <nil>, names: [], fields: {}, meta: map[]}, shape: [2 3 4], strides: [96 32 8], fortran: false, data: [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]}`
29+
got = fmt.Sprintf("%v", arr)
30+
)
31+
32+
if got != want {
33+
t.Fatalf("invalid array display:\ngot= %s\nwant=%s", got, want)
34+
}
35+
36+
if got, want := arr.Descr().kind, byte('f'); got != want {
37+
t.Fatalf("invalid kind: got=%c, want=%c", got, want)
38+
}
39+
40+
if got, want := arr.Shape(), []int{2, 3, 4}; !reflect.DeepEqual(got, want) {
41+
t.Fatalf("invalid shape:\ngot= %+v\nwant=%+v", got, want)
42+
}
43+
44+
if got, want := arr.Strides(), []int{96, 32, 8}; !reflect.DeepEqual(got, want) {
45+
t.Fatalf("invalid strides:\ngot= %+v\nwant=%+v", got, want)
46+
}
47+
48+
if got, want := arr.Fortran(), false; got != want {
49+
t.Fatalf("invalid fortran:\ngot= %+v\nwant=%+v", got, want)
50+
}
51+
}

0 commit comments

Comments
 (0)