Skip to content

Commit 4d82ab5

Browse files
yuvaltassacopybara-github
authored andcommitted
Add function to merge chains of indices.
PiperOrigin-RevId: 712863690 Change-Id: I15652bec03dc9ce90e230788bd12b44b1a2b8217
1 parent 1c69e64 commit 4d82ab5

File tree

3 files changed

+83
-11
lines changed

3 files changed

+83
-11
lines changed

src/engine/engine_util_sparse.c

-11
Original file line numberDiff line numberDiff line change
@@ -206,17 +206,6 @@ static void mju_addToSclScl(mjtNum* res, const mjtNum* vec, mjtNum scl1, mjtNum
206206

207207

208208

209-
// return 1 if vec1==vec2, 0 otherwise
210-
static int mju_compare(const int* vec1, const int* vec2, int n) {
211-
#ifdef mjUSEAVX
212-
return mju_compare_avx(vec1, vec2, n);
213-
#else
214-
return !memcmp(vec1, vec2, n*sizeof(int));
215-
#endif // mjUSEAVX
216-
}
217-
218-
219-
220209
// count the number of non-zeros in the sum of two sparse vectors
221210
int mju_combineSparseCount(int a_nnz, int b_nnz, const int* a_ind, const int* b_ind) {
222211
int a = 0, b = 0, c_nnz = 0;

src/engine/engine_util_sparse.h

+65
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#ifndef MUJOCO_SRC_ENGINE_ENGINE_UTIL_SPARSE_H_
1616
#define MUJOCO_SRC_ENGINE_ENGINE_UTIL_SPARSE_H_
1717

18+
#include <string.h>
19+
1820
#include <mujoco/mjdata.h>
1921
#include <mujoco/mjexport.h>
2022
#include <mujoco/mjtnum.h>
@@ -162,6 +164,69 @@ mjtNum mju_dotSparse(const mjtNum* vec1, const mjtNum* vec2, int nnz1, const int
162164
#endif // mjUSEAVX
163165
}
164166

167+
// return 1 if vec1==vec2, 0 otherwise
168+
static inline
169+
int mju_compare(const int* vec1, const int* vec2, int n) {
170+
#ifdef mjUSEAVX
171+
return mju_compare_avx(vec1, vec2, n);
172+
#else
173+
return !memcmp(vec1, vec2, n*sizeof(int));
174+
#endif // mjUSEAVX
175+
}
176+
177+
178+
// merge unique sorted integers, merge array must be large enough (not checked for)
179+
static inline
180+
int mj_mergeSorted(int* merge, const int* chain1, int n1, const int* chain2, int n2) {
181+
// special case: one or both empty
182+
if (n1 == 0) {
183+
if (n2 == 0) {
184+
return 0;
185+
}
186+
memcpy(merge, chain2, n2 * sizeof(int));
187+
return n2;
188+
} else if (n2 == 0) {
189+
memcpy(merge, chain1, n1 * sizeof(int));
190+
return n1;
191+
}
192+
193+
// special case: identical pattern
194+
if (n1 == n2 && mju_compare(chain1, chain2, n1)) {
195+
memcpy(merge, chain1, n1 * sizeof(int));
196+
return n1;
197+
}
198+
199+
// merge while both chains are non-empty
200+
int i = 0, j = 0, k = 0;
201+
while (i < n1 && j < n2) {
202+
int c1 = chain1[i];
203+
int c2 = chain2[j];
204+
205+
if (c1 < c2) {
206+
merge[k++] = c1;
207+
i++;
208+
} else if (c1 > c2) {
209+
merge[k++] = c2;
210+
j++;
211+
} else { // c1 == c2
212+
merge[k++] = c1;
213+
i++;
214+
j++;
215+
}
216+
}
217+
218+
// copy remaining
219+
if (i < n1) {
220+
memcpy(merge + k, chain1 + i, (n1 - i)*sizeof(int));
221+
k += n1 - i;
222+
} else if (j < n2) {
223+
memcpy(merge + k, chain2 + j, (n2 - j)*sizeof(int));
224+
k += n2 - j;
225+
}
226+
227+
return k;
228+
}
229+
165230

166231
#ifdef __cplusplus
167232
}

test/engine/engine_util_sparse_test.cc

+18
Original file line numberDiff line numberDiff line change
@@ -1100,5 +1100,23 @@ TEST_F(EngineUtilSparseTest, MjuDenseToSparse) {
11001100
EXPECT_EQ(status0, 1);
11011101
}
11021102

1103+
TEST_F(EngineUtilSparseTest, MergeSorted) {
1104+
const int chain1_a[] = {1, 2, 3};
1105+
const int chain2_a[] = {};
1106+
int merged_a[3];
1107+
int n1 = 3;
1108+
int n2 = 0;
1109+
EXPECT_EQ(mj_mergeSorted(merged_a, chain1_a, n1, chain2_a, n2), 3);
1110+
EXPECT_THAT(merged_a, ElementsAre(1, 2, 3));
1111+
1112+
const int chain1_b[] = {1, 3, 5, 7, 8};
1113+
const int chain2_b[] = {2, 4, 5, 6, 8};
1114+
int merged_b[8];
1115+
n1 = 5;
1116+
n2 = 5;
1117+
EXPECT_EQ(mj_mergeSorted(merged_b, chain1_b, n1, chain2_b, n2), 8);
1118+
EXPECT_THAT(merged_b, ElementsAre(1, 2, 3, 4, 5, 6, 7, 8));
1119+
}
1120+
11031121
} // namespace
11041122
} // namespace mujoco

0 commit comments

Comments
 (0)