44import pytest
55
66from mlem .contrib .lightgbm import (
7+ LIGHTGBM_DATA ,
8+ LIGHTGBM_LABEL ,
79 LightGBMDataReader ,
810 LightGBMDataType ,
911 LightGBMDataWriter ,
1214from mlem .contrib .numpy import NumpyNdarrayType
1315from mlem .contrib .pandas import DataFrameType
1416from mlem .core .artifacts import LOCAL_STORAGE
15- from mlem .core .data_type import DataAnalyzer , DataType
17+ from mlem .core .data_type import (
18+ ArrayType ,
19+ DataAnalyzer ,
20+ DataType ,
21+ PrimitiveType ,
22+ )
1623from mlem .core .errors import DeserializationError , SerializationError
1724from mlem .core .model import ModelAnalyzer , ModelType
1825from mlem .core .requirements import UnixPackageRequirement
@@ -46,7 +53,7 @@ def df_payload():
def data_df(df_payload):
    """Wrap ``df_payload`` in a LightGBM dataset with two labels.

    The labels are passed as a NumPy array, and ``free_raw_data=False`` is
    set so the raw data stays attached to the dataset.
    """
    return lgb.Dataset(
        df_payload,
        label=np.array([0, 1]),
        free_raw_data=False,
    )
5259
@@ -75,6 +82,8 @@ def test_hook_np(dtype_np: DataType):
7582 assert set (dtype_np .get_requirements ().modules ) == {"lightgbm" , "numpy" }
7683 assert isinstance (dtype_np , LightGBMDataType )
7784 assert isinstance (dtype_np .inner , NumpyNdarrayType )
85+ assert isinstance (dtype_np .labels , ArrayType )
86+ assert dtype_np .labels .dtype == PrimitiveType (data = None , ptype = "float" )
7887 assert dtype_np .get_model ().__name__ == dtype_np .inner .get_model ().__name__
7988 assert dtype_np .get_model ().schema () == {
8089 "title" : "NumpyNdarray" ,
@@ -92,6 +101,7 @@ def test_hook_df(dtype_df: DataType):
92101 assert set (dtype_df .get_requirements ().modules ) == {"lightgbm" , "pandas" }
93102 assert isinstance (dtype_df , LightGBMDataType )
94103 assert isinstance (dtype_df .inner , DataFrameType )
104+ assert isinstance (dtype_df .labels , NumpyNdarrayType )
95105 assert dtype_df .get_model ().__name__ == dtype_df .inner .get_model ().__name__
96106 assert dtype_df .get_model ().schema () == {
97107 "title" : "DataFrame" ,
@@ -116,54 +126,131 @@ def test_hook_df(dtype_df: DataType):
116126
117127
@pytest.mark.parametrize(
    "lgb_dtype, data_type, label_type",
    [
        ("dtype_np", NumpyNdarrayType, ArrayType),
        ("dtype_df", DataFrameType, NumpyNdarrayType),
    ],
)
def test_lightgbm_source(lgb_dtype, data_type, label_type, request):
    """Round-trip a LightGBM data type through its writer and reader.

    Checks the inner and label types of the fixture named by ``lgb_dtype``,
    runs the write/read check, and then verifies the keys and URI suffixes of
    the artifacts it returns.
    """
    lgb_dtype = request.getfixturevalue(lgb_dtype)
    assert isinstance(lgb_dtype, LightGBMDataType)
    assert isinstance(lgb_dtype.inner, data_type)
    assert isinstance(lgb_dtype.labels, label_type)

    def custom_assert(x, y):
        assert hasattr(x, "data")
        assert hasattr(y, "data")
        assert all(x.data == y.data)
        # The comparison yields a sequence for list/array labels and a plain
        # truth value otherwise, so both shapes have to be handled.
        label_check = x.label == y.label
        if isinstance(label_check, (list, np.ndarray)):
            assert all(label_check)
        else:
            assert label_check

    artifacts = data_write_read_check(
        lgb_dtype,
        writer=LightGBMDataWriter(),
        reader_type=LightGBMDataReader,
        custom_assert=custom_assert,
    )

    # With a numpy inner type the labels are stored one artifact per element
    # (five of them); otherwise the labels are stored as a single artifact.
    if isinstance(lgb_dtype.inner, NumpyNdarrayType):
        label_paths = [f"{LIGHTGBM_LABEL}/{idx}" for idx in range(5)]
    else:
        label_paths = [LIGHTGBM_LABEL]
    expected_paths = [LIGHTGBM_DATA, *label_paths]

    assert list(artifacts.keys()) == [f"{path}/data" for path in expected_paths]
    for path in expected_paths:
        assert artifacts[f"{path}/data"].uri.endswith(f"data/{path}")
196+
140197
def test_serialize__np(dtype_np, np_payload):
    """Serialize a numpy-backed dataset and check both data and labels.

    The flattened payload doubles as the label vector. Serializing an object
    that is not a dataset must raise ``SerializationError``.
    """
    flat_labels = np_payload.reshape((-1,)).tolist()
    ds = lgb.Dataset(np_payload, label=flat_labels)
    payload = dtype_np.serialize(ds)
    assert payload[LIGHTGBM_DATA] == np_payload.tolist()
    assert payload[LIGHTGBM_LABEL] == np_payload.reshape((-1,)).tolist()

    with pytest.raises(SerializationError):
        dtype_np.serialize({"abc": 123})  # wrong type
148206
149207
def test_deserialize__np(dtype_np, np_payload):
    """Deserialize a data/label mapping into a LightGBM dataset.

    The flattened payload doubles as the label vector. A ragged, mixed-type
    matrix under the data key must raise ``DeserializationError``.
    """
    expected_labels = np_payload.reshape((-1,)).tolist()
    ds = dtype_np.deserialize(
        {
            LIGHTGBM_DATA: np_payload,
            LIGHTGBM_LABEL: np_payload.reshape((-1,)).tolist(),
        }
    )
    assert isinstance(ds, lgb.Dataset)
    assert np.all(ds.data == np_payload)
    assert np.all(ds.label == expected_labels)

    with pytest.raises(DeserializationError):
        dtype_np.deserialize({LIGHTGBM_DATA: [[1], ["abc"]]})  # illegal matrix
220+ dtype_np .deserialize ({ LIGHTGBM_DATA : [[1 ], ["abc" ]]} ) # illegal matrix
157221
158222
def test_serialize__df(df_payload):
    """Serialize and round-trip a DataFrame-backed dataset built without labels.

    The data type is derived from the dataset itself via ``DataType.create``.
    The serialized output is expected to expose the frame records under
    ``"values"`` and to carry no label entry, and the write/read check is
    expected to produce a single ``"data"`` artifact.
    """
    ds = lgb.Dataset(df_payload, label=None, free_raw_data=False)
    payload = DataType.create(obj=ds)
    serialized = payload.serialize(ds)
    assert serialized["values"] == df_payload.to_dict("records")
    # Check the serialized output rather than the data type object itself:
    # membership on ``payload`` would not inspect what was actually emitted.
    assert LIGHTGBM_LABEL not in serialized

    def custom_assert(x, y):
        assert hasattr(x, "data")
        assert hasattr(y, "data")
        assert all(x.data == y.data)
        assert x.label == y.label

    artifacts = data_write_read_check(
        payload,
        writer=LightGBMDataWriter(),
        reader_type=LightGBMDataReader,
        custom_assert=custom_assert,
    )

    assert list(artifacts.keys()) == ["data"]
    assert artifacts["data"].uri.endswith("/data")
163245
164246
165247def test_deserialize__df (dtype_df , df_payload ):
166- ds = dtype_df .deserialize ({"values" : df_payload })
248+ ds = dtype_df .deserialize (
249+ {
250+ LIGHTGBM_DATA : {"values" : df_payload },
251+ LIGHTGBM_LABEL : np .array ([0 , 1 ]).tolist (),
252+ }
253+ )
167254 assert isinstance (ds , lgb .Dataset )
168255 assert ds .data .equals (df_payload )
169256
0 commit comments