648db22b |
1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under both the BSD-style license (found in the |
6 | * LICENSE file in the root directory of this source tree) and the GPLv2 (found |
7 | * in the COPYING file in the root directory of this source tree). |
8 | * You may select, at your option, one of the above-listed licenses. |
9 | */ |
10 | |
11 | /*-************************************* |
12 | * Dependencies |
13 | ***************************************/ |
14 | |
15 | /* Currently relies on qsort when combining contiguous matches. This can probably |
16 | * be avoided but would require changes to the algorithm. The qsort is far from |
17 | * the bottleneck in this algorithm even for medium sized files so it's probably |
18 | * not worth trying to address */ |
19 | #include <stdlib.h> |
20 | #include <assert.h> |
21 | |
22 | #include "zstd_edist.h" |
23 | #include "mem.h" |
24 | |
25 | /*-************************************* |
26 | * Constants |
27 | ***************************************/ |
28 | |
/* Just a sentinel for the entries of the diagonal matrix */
30 | #define ZSTD_EDIST_DIAG_MAX (S32)(1 << 30) |
31 | |
32 | /* How large should a snake be to be considered a 'big' snake. |
33 | * For an explanation of what a 'snake' is with respect to the |
34 | * edit distance matrix, see the linked paper in zstd_edist.h */ |
35 | #define ZSTD_EDIST_SNAKE_THRESH 20 |
36 | |
37 | /* After how many iterations should we start to use the heuristic |
38 | * based on 'big' snakes */ |
39 | #define ZSTD_EDIST_SNAKE_ITER_THRESH 200 |
40 | |
/* After how many iterations should we just give up and take
 * the best available edit script for this round */
43 | #define ZSTD_EDIST_EXPENSIVE_THRESH 1024 |
44 | |
45 | /*-************************************* |
46 | * Structures |
47 | ***************************************/ |
48 | |
/* One match between the dictionary and the source.
 * Matches are first accumulated with matchLength == 1 and later
 * merged into longer runs by ZSTD_eDist_combineMatches(). */
typedef struct {
    U32 dictIdx;        /* Starting index of the match in the dictionary */
    U32 srcIdx;         /* Starting index of the match in the source */
    U32 matchLength;    /* Number of matching bytes */
} ZSTD_eDist_match;
54 | |
/* Shared state for one edit-distance computation over (dict, src). */
typedef struct {
    const BYTE* dict;
    const BYTE* src;
    size_t dictSize;
    size_t srcSize;
    S32* forwardDiag;            /* Entries of the forward diagonal stored here */
    S32* backwardDiag;           /* Entries of the backward diagonal stored here.
                                  * Note: this buffer and the 'forwardDiag' buffer
                                  * are contiguous. See the ZSTD_eDist_genSequences */
    ZSTD_eDist_match* matches;   /* Accumulate matches of length 1 in this buffer.
                                  * In a subsequent post-processing step, we combine
                                  * contiguous matches. */
    U32 nbMatches;               /* Number of entries currently in 'matches' */
} ZSTD_eDist_state;
69 | |
/* Result of one divide step: the point where the edit-distance
 * rectangle is split before recursing on each half. */
typedef struct {
    S32 dictMid;          /* The mid diagonal for the dictionary */
    S32 srcMid;           /* The mid diagonal for the source */
    int lowUseHeuristics;  /* Should we use heuristics for the low part */
    int highUseHeuristics; /* Should we use heuristics for the high part */
} ZSTD_eDist_partition;
76 | |
77 | /*-************************************* |
78 | * Internal |
79 | ***************************************/ |
80 | |
/* Finds the split point of the edit-distance rectangle
 * (dictLow, srcLow)..(dictHigh, srcHigh) by running forward and backward
 * diagonal searches simultaneously until they meet — the divide step of
 * the algorithm described in the paper linked from zstd_edist.h.
 *
 * Diagonals are indexed by k = dictIdx - srcIdx. forwardDiag[k] holds the
 * furthest dictionary index reached on diagonal k from the low corner;
 * backwardDiag[k] holds the lowest dictionary index reached from the high
 * corner. The split point and per-half heuristic flags are written into
 * 'partition'. When 'useHeuristics' is set, two shortcuts may end the
 * search early at the cost of optimality (see comments below). */
static void ZSTD_eDist_diag(ZSTD_eDist_state* state,
                ZSTD_eDist_partition* partition,
                S32 dictLow, S32 dictHigh, S32 srcLow,
                S32 srcHigh, int useHeuristics)
{
    S32* const forwardDiag = state->forwardDiag;
    S32* const backwardDiag = state->backwardDiag;
    const BYTE* const dict = state->dict;
    const BYTE* const src = state->src;

    /* Range of reachable diagonals for this sub-rectangle */
    S32 const diagMin = dictLow - srcHigh;
    S32 const diagMax = dictHigh - srcLow;
    S32 const forwardMid = dictLow - srcLow;    /* starting diagonal, forward search */
    S32 const backwardMid = dictHigh - srcHigh; /* starting diagonal, backward search */

    S32 forwardMin = forwardMid;
    S32 forwardMax = forwardMid;
    S32 backwardMin = backwardMid;
    S32 backwardMax = backwardMid;
    /* Parity decides whether overlap can be detected in the forward
     * pass (odd) or the backward pass (even) */
    int odd = (forwardMid - backwardMid) & 1;
    U32 iterations;

    forwardDiag[forwardMid] = dictLow;
    backwardDiag[backwardMid] = dictHigh;

    /* Main loop for updating diag entries. Unless useHeuristics is
     * set to false, this loop will run until it finds the minimal
     * edit script */
    for (iterations = 1;;iterations++) {
        S32 diag;
        int bigSnake = 0;

        /* Widen the forward band by one diagonal on each side,
         * seeding the new outer neighbor with a -1 sentinel */
        if (forwardMin > diagMin) {
            forwardMin--;
            forwardDiag[forwardMin - 1] = -1;
        } else {
            forwardMin++;
        }

        if (forwardMax < diagMax) {
            forwardMax++;
            forwardDiag[forwardMax + 1] = -1;
        } else {
            forwardMax--;
        }

        /* Forward pass: for every other diagonal, step from the better
         * neighbor, then slide along the snake (run of equal bytes) */
        for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
            S32 dictIdx;
            S32 srcIdx;
            S32 low = forwardDiag[diag - 1];
            S32 high = forwardDiag[diag + 1];
            S32 dictIdx0 = low < high ? high : low + 1;

            for (dictIdx = dictIdx0, srcIdx = dictIdx0 - diag;
                    dictIdx < dictHigh && srcIdx < srcHigh && dict[dictIdx] == src[srcIdx];
                    dictIdx++, srcIdx++) continue;

            if (dictIdx - dictIdx0 > ZSTD_EDIST_SNAKE_THRESH)
                bigSnake = 1;

            forwardDiag[diag] = dictIdx;

            /* The forward path overlapped the backward path: done */
            if (odd && backwardMin <= diag && diag <= backwardMax && backwardDiag[diag] <= dictIdx) {
                partition->dictMid = dictIdx;
                partition->srcMid = srcIdx;
                partition->lowUseHeuristics = 0;
                partition->highUseHeuristics = 0;
                return;
            }
        }

        /* Widen the backward band, using the large sentinel instead */
        if (backwardMin > diagMin) {
            backwardMin--;
            backwardDiag[backwardMin - 1] = ZSTD_EDIST_DIAG_MAX;
        } else {
            backwardMin++;
        }

        if (backwardMax < diagMax) {
            backwardMax++;
            backwardDiag[backwardMax + 1] = ZSTD_EDIST_DIAG_MAX;
        } else {
            backwardMax--;
        }


        /* Backward pass: mirror of the forward pass, sliding the
         * snake towards the low corner */
        for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
            S32 dictIdx;
            S32 srcIdx;
            S32 low = backwardDiag[diag - 1];
            S32 high = backwardDiag[diag + 1];
            S32 dictIdx0 = low < high ? low : high - 1;

            for (dictIdx = dictIdx0, srcIdx = dictIdx0 - diag;
                    dictLow < dictIdx && srcLow < srcIdx && dict[dictIdx - 1] == src[srcIdx - 1];
                    dictIdx--, srcIdx--) continue;

            if (dictIdx0 - dictIdx > ZSTD_EDIST_SNAKE_THRESH)
                bigSnake = 1;

            backwardDiag[diag] = dictIdx;

            /* The backward path overlapped the forward path: done */
            if (!odd && forwardMin <= diag && diag <= forwardMax && dictIdx <= forwardDiag[diag]) {
                partition->dictMid = dictIdx;
                partition->srcMid = srcIdx;
                partition->lowUseHeuristics = 0;
                partition->highUseHeuristics = 0;
                return;
            }
        }

        if (!useHeuristics)
            continue;

        /* Everything under this point is a heuristic. Using these will
         * substantially speed up the match finding. In some cases, taking
         * the total match finding time from several minutes to seconds.
         * Of course, the caveat is that the edit script found may no longer
         * be optimal */

        /* Big snake heuristic: after enough iterations, if a long run of
         * equal bytes was seen, accept the best-scoring forward (or
         * backward) position as the split point instead of continuing */
        if (iterations > ZSTD_EDIST_SNAKE_ITER_THRESH && bigSnake) {
            {
                S32 best = 0;

                for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
                    S32 diagDiag = diag - forwardMid;
                    S32 dictIdx = forwardDiag[diag];
                    S32 srcIdx = dictIdx - diag;
                    /* Score: progress made minus distance from the mid diagonal */
                    S32 v = (dictIdx - dictLow) * 2 - diagDiag;

                    if (v > 12 * (iterations + (diagDiag < 0 ? -diagDiag : diagDiag))) {
                        if (v > best
                                && dictLow + ZSTD_EDIST_SNAKE_THRESH <= dictIdx && dictIdx <= dictHigh
                                && srcLow + ZSTD_EDIST_SNAKE_THRESH <= srcIdx && srcIdx <= srcHigh) {
                            S32 k;
                            /* Only accept if a full snake of bytes precedes the point */
                            for (k = 1; dict[dictIdx - k] == src[srcIdx - k]; k++) {
                                if (k == ZSTD_EDIST_SNAKE_THRESH) {
                                    best = v;
                                    partition->dictMid = dictIdx;
                                    partition->srcMid = srcIdx;
                                    break;
                                }
                            }
                        }
                    }
                }

                if (best > 0) {
                    /* Low half ends on a verified snake: no heuristics needed there */
                    partition->lowUseHeuristics = 0;
                    partition->highUseHeuristics = 1;
                    return;
                }
            }

            {
                S32 best = 0;

                for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
                    S32 diagDiag = diag - backwardMid;
                    S32 dictIdx = backwardDiag[diag];
                    S32 srcIdx = dictIdx - diag;
                    S32 v = (dictHigh - dictIdx) * 2 + diagDiag;

                    if (v > 12 * (iterations + (diagDiag < 0 ? -diagDiag : diagDiag))) {
                        if (v > best
                                && dictLow < dictIdx && dictIdx <= dictHigh - ZSTD_EDIST_SNAKE_THRESH
                                && srcLow < srcIdx && srcIdx <= srcHigh - ZSTD_EDIST_SNAKE_THRESH) {
                            int k;
                            /* Only accept if a full snake of bytes follows the point */
                            for (k = 0; dict[dictIdx + k] == src[srcIdx + k]; k++) {
                                if (k == ZSTD_EDIST_SNAKE_THRESH - 1) {
                                    best = v;
                                    partition->dictMid = dictIdx;
                                    partition->srcMid = srcIdx;
                                    break;
                                }
                            }
                        }
                    }
                }

                if (best > 0) {
                    partition->lowUseHeuristics = 1;
                    /* High half starts on a verified snake: no heuristics needed there */
                    partition->highUseHeuristics = 0;
                    return;
                }
            }
        }

        /* More general 'too expensive' heuristic: give up and split at
         * whichever search has made the most total progress so far */
        if (iterations >= ZSTD_EDIST_EXPENSIVE_THRESH) {
            S32 forwardDictSrcBest;
            S32 forwardDictBest = 0;
            S32 backwardDictSrcBest;
            S32 backwardDictBest = 0;

            forwardDictSrcBest = -1;
            for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
                /* Clamp positions back inside the rectangle before scoring */
                S32 dictIdx = MIN(forwardDiag[diag], dictHigh);
                S32 srcIdx = dictIdx - diag;

                if (srcHigh < srcIdx) {
                    dictIdx = srcHigh + diag;
                    srcIdx = srcHigh;
                }

                if (forwardDictSrcBest < dictIdx + srcIdx) {
                    forwardDictSrcBest = dictIdx + srcIdx;
                    forwardDictBest = dictIdx;
                }
            }

            backwardDictSrcBest = ZSTD_EDIST_DIAG_MAX;
            for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
                S32 dictIdx = MAX(dictLow, backwardDiag[diag]);
                S32 srcIdx = dictIdx - diag;

                if (srcIdx < srcLow) {
                    dictIdx = srcLow + diag;
                    srcIdx = srcLow;
                }

                if (dictIdx + srcIdx < backwardDictSrcBest) {
                    backwardDictSrcBest = dictIdx + srcIdx;
                    backwardDictBest = dictIdx;
                }
            }

            /* Split at whichever frontier is further along; the untouched
             * half keeps using heuristics on the next recursion level */
            if ((dictHigh + srcHigh) - backwardDictSrcBest < forwardDictSrcBest - (dictLow + srcLow)) {
                partition->dictMid = forwardDictBest;
                partition->srcMid = forwardDictSrcBest - forwardDictBest;
                partition->lowUseHeuristics = 0;
                partition->highUseHeuristics = 1;
            } else {
                partition->dictMid = backwardDictBest;
                partition->srcMid = backwardDictSrcBest - backwardDictBest;
                partition->lowUseHeuristics = 1;
                partition->highUseHeuristics = 0;
            }
            return;
        }
    }
}
324 | |
325 | static void ZSTD_eDist_insertMatch(ZSTD_eDist_state* state, |
326 | S32 const dictIdx, S32 const srcIdx) |
327 | { |
328 | state->matches[state->nbMatches].dictIdx = dictIdx; |
329 | state->matches[state->nbMatches].srcIdx = srcIdx; |
330 | state->matches[state->nbMatches].matchLength = 1; |
331 | state->nbMatches++; |
332 | } |
333 | |
334 | static int ZSTD_eDist_compare(ZSTD_eDist_state* state, |
335 | S32 dictLow, S32 dictHigh, S32 srcLow, |
336 | S32 srcHigh, int useHeuristics) |
337 | { |
338 | const BYTE* const dict = state->dict; |
339 | const BYTE* const src = state->src; |
340 | |
341 | /* Found matches while traversing from the low end */ |
342 | while (dictLow < dictHigh && srcLow < srcHigh && dict[dictLow] == src[srcLow]) { |
343 | ZSTD_eDist_insertMatch(state, dictLow, srcLow); |
344 | dictLow++; |
345 | srcLow++; |
346 | } |
347 | |
348 | /* Found matches while traversing from the high end */ |
349 | while (dictLow < dictHigh && srcLow < srcHigh && dict[dictHigh - 1] == src[srcHigh - 1]) { |
350 | ZSTD_eDist_insertMatch(state, dictHigh - 1, srcHigh - 1); |
351 | dictHigh--; |
352 | srcHigh--; |
353 | } |
354 | |
355 | /* If the low and high end end up touching. If we wanted to make |
356 | * note of the differences like most diffing algorithms do, we would |
357 | * do so here. In our case, we're only concerned with matches |
358 | * Note: if you wanted to find the edit distance of the algorithm, |
359 | * you could just accumulate the cost for an insertion/deletion |
360 | * below. */ |
361 | if (dictLow == dictHigh) { |
362 | while (srcLow < srcHigh) { |
363 | /* Reaching this point means inserting src[srcLow] into |
364 | * the current position of dict */ |
365 | srcLow++; |
366 | } |
367 | } else if (srcLow == srcHigh) { |
368 | while (dictLow < dictHigh) { |
369 | /* Reaching this point means deleting dict[dictLow] from |
370 | * the current position of dict */ |
371 | dictLow++; |
372 | } |
373 | } else { |
374 | ZSTD_eDist_partition partition; |
375 | partition.dictMid = 0; |
376 | partition.srcMid = 0; |
377 | ZSTD_eDist_diag(state, &partition, dictLow, dictHigh, |
378 | srcLow, srcHigh, useHeuristics); |
379 | if (ZSTD_eDist_compare(state, dictLow, partition.dictMid, |
380 | srcLow, partition.srcMid, partition.lowUseHeuristics)) |
381 | return 1; |
382 | if (ZSTD_eDist_compare(state, partition.dictMid, dictHigh, |
383 | partition.srcMid, srcHigh, partition.highUseHeuristics)) |
384 | return 1; |
385 | } |
386 | |
387 | return 0; |
388 | } |
389 | |
390 | static int ZSTD_eDist_matchComp(const void* p, const void* q) |
391 | { |
392 | S32 const l = ((ZSTD_eDist_match*)p)->srcIdx; |
393 | S32 const r = ((ZSTD_eDist_match*)q)->srcIdx; |
394 | return (l - r); |
395 | } |
396 | |
397 | /* The matches from the approach above will all be of the form |
398 | * (dictIdx, srcIdx, 1). This method combines contiguous matches |
399 | * of length MINMATCH or greater. Matches less than MINMATCH |
400 | * are discarded */ |
401 | static void ZSTD_eDist_combineMatches(ZSTD_eDist_state* state) |
402 | { |
403 | /* Create a new buffer to put the combined matches into |
404 | * and memcpy to state->matches after */ |
405 | ZSTD_eDist_match* combinedMatches = |
406 | ZSTD_malloc(state->nbMatches * sizeof(ZSTD_eDist_match), |
407 | ZSTD_defaultCMem); |
408 | |
409 | U32 nbCombinedMatches = 1; |
410 | size_t i; |
411 | |
412 | /* Make sure that the srcIdx and dictIdx are in sorted order. |
413 | * The combination step won't work otherwise */ |
414 | qsort(state->matches, state->nbMatches, sizeof(ZSTD_eDist_match), ZSTD_eDist_matchComp); |
415 | |
416 | memcpy(combinedMatches, state->matches, sizeof(ZSTD_eDist_match)); |
417 | for (i = 1; i < state->nbMatches; i++) { |
418 | ZSTD_eDist_match const match = state->matches[i]; |
419 | ZSTD_eDist_match const combinedMatch = |
420 | combinedMatches[nbCombinedMatches - 1]; |
421 | if (combinedMatch.srcIdx + combinedMatch.matchLength == match.srcIdx && |
422 | combinedMatch.dictIdx + combinedMatch.matchLength == match.dictIdx) { |
423 | combinedMatches[nbCombinedMatches - 1].matchLength++; |
424 | } else { |
425 | /* Discard matches that are less than MINMATCH */ |
426 | if (combinedMatches[nbCombinedMatches - 1].matchLength < MINMATCH) { |
427 | nbCombinedMatches--; |
428 | } |
429 | |
430 | memcpy(combinedMatches + nbCombinedMatches, |
431 | state->matches + i, sizeof(ZSTD_eDist_match)); |
432 | nbCombinedMatches++; |
433 | } |
434 | } |
435 | memcpy(state->matches, combinedMatches, nbCombinedMatches * sizeof(ZSTD_eDist_match)); |
436 | state->nbMatches = nbCombinedMatches; |
437 | ZSTD_free(combinedMatches, ZSTD_defaultCMem); |
438 | } |
439 | |
440 | static size_t ZSTD_eDist_convertMatchesToSequences(ZSTD_Sequence* sequences, |
441 | ZSTD_eDist_state* state) |
442 | { |
443 | const ZSTD_eDist_match* matches = state->matches; |
444 | size_t const nbMatches = state->nbMatches; |
445 | size_t const dictSize = state->dictSize; |
446 | size_t nbSequences = 0; |
447 | size_t i; |
448 | for (i = 0; i < nbMatches; i++) { |
449 | ZSTD_eDist_match const match = matches[i]; |
450 | U32 const litLength = !i ? match.srcIdx : |
451 | match.srcIdx - (matches[i - 1].srcIdx + matches[i - 1].matchLength); |
452 | U32 const offset = (match.srcIdx + dictSize) - match.dictIdx; |
453 | U32 const matchLength = match.matchLength; |
454 | sequences[nbSequences].offset = offset; |
455 | sequences[nbSequences].litLength = litLength; |
456 | sequences[nbSequences].matchLength = matchLength; |
457 | nbSequences++; |
458 | } |
459 | return nbSequences; |
460 | } |
461 | |
462 | /*-************************************* |
463 | * Internal utils |
464 | ***************************************/ |
465 | |
466 | static size_t ZSTD_eDist_hamingDist(const BYTE* const a, |
467 | const BYTE* const b, size_t n) |
468 | { |
469 | size_t i; |
470 | size_t dist = 0; |
471 | for (i = 0; i < n; i++) |
472 | dist += a[i] != b[i]; |
473 | return dist; |
474 | } |
475 | |
476 | /* This is a pretty naive recursive implementation that should only |
477 | * be used for quick tests obviously. Don't try and run this on a |
478 | * GB file or something. There are faster implementations. Use those |
479 | * if you need to run it for large files. */ |
480 | static size_t ZSTD_eDist_levenshteinDist(const BYTE* const s, |
481 | size_t const sn, const BYTE* const t, |
482 | size_t const tn) |
483 | { |
484 | size_t a, b, c; |
485 | |
486 | if (!sn) |
487 | return tn; |
488 | if (!tn) |
489 | return sn; |
490 | |
491 | if (s[sn - 1] == t[tn - 1]) |
492 | return ZSTD_eDist_levenshteinDist( |
493 | s, sn - 1, t, tn - 1); |
494 | |
495 | a = ZSTD_eDist_levenshteinDist(s, sn - 1, t, tn - 1); |
496 | b = ZSTD_eDist_levenshteinDist(s, sn, t, tn - 1); |
497 | c = ZSTD_eDist_levenshteinDist(s, sn - 1, t, tn); |
498 | |
499 | if (a > b) |
500 | a = b; |
501 | if (a > c) |
502 | a = c; |
503 | |
504 | return a + 1; |
505 | } |
506 | |
507 | static void ZSTD_eDist_validateMatches(ZSTD_eDist_match* matches, |
508 | size_t const nbMatches, const BYTE* const dict, |
509 | size_t const dictSize, const BYTE* const src, |
510 | size_t const srcSize) |
511 | { |
512 | size_t i; |
513 | for (i = 0; i < nbMatches; i++) { |
514 | ZSTD_eDist_match match = matches[i]; |
515 | U32 const dictIdx = match.dictIdx; |
516 | U32 const srcIdx = match.srcIdx; |
517 | U32 const matchLength = match.matchLength; |
518 | |
519 | assert(dictIdx + matchLength < dictSize); |
520 | assert(srcIdx + matchLength < srcSize); |
521 | assert(!memcmp(dict + dictIdx, src + srcIdx, matchLength)); |
522 | } |
523 | } |
524 | |
525 | /*-************************************* |
526 | * API |
527 | ***************************************/ |
528 | |
529 | size_t ZSTD_eDist_genSequences(ZSTD_Sequence* sequences, |
530 | const void* dict, size_t dictSize, |
531 | const void* src, size_t srcSize, |
532 | int useHeuristics) |
533 | { |
534 | size_t const nbDiags = dictSize + srcSize + 3; |
535 | S32* buffer = ZSTD_malloc(nbDiags * 2 * sizeof(S32), ZSTD_defaultCMem); |
536 | ZSTD_eDist_state state; |
537 | size_t nbSequences = 0; |
538 | |
539 | state.dict = (const BYTE*)dict; |
540 | state.src = (const BYTE*)src; |
541 | state.dictSize = dictSize; |
542 | state.srcSize = srcSize; |
543 | state.forwardDiag = buffer; |
544 | state.backwardDiag = buffer + nbDiags; |
545 | state.forwardDiag += srcSize + 1; |
546 | state.backwardDiag += srcSize + 1; |
547 | state.matches = ZSTD_malloc(srcSize * sizeof(ZSTD_eDist_match), ZSTD_defaultCMem); |
548 | state.nbMatches = 0; |
549 | |
550 | ZSTD_eDist_compare(&state, 0, dictSize, 0, srcSize, 1); |
551 | ZSTD_eDist_combineMatches(&state); |
552 | nbSequences = ZSTD_eDist_convertMatchesToSequences(sequences, &state); |
553 | |
554 | ZSTD_free(buffer, ZSTD_defaultCMem); |
555 | ZSTD_free(state.matches, ZSTD_defaultCMem); |
556 | |
557 | return nbSequences; |
558 | } |