2 * Copyright (c) Meta Platforms, Inc. and affiliates.
5 * This source code is licensed under both the BSD-style license (found in the
6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7 * in the COPYING file in the root directory of this source tree).
8 * You may select, at your option, one of the above-listed licenses.
11 /*-*************************************
13 ***************************************/
15 /* Currently relies on qsort when combining contiguous matches. This can probably
16 * be avoided but would require changes to the algorithm. The qsort is far from
17 * the bottleneck in this algorithm even for medium sized files so it's probably
18 * not worth trying to address */
22 #include "zstd_edist.h"
25 /*-*************************************
27 ***************************************/
/* Just a sentinel for the entries of the diagonal matrix.
 * The expansion is fully parenthesized so that the macro always behaves as
 * a single operand, whatever operator it is placed next to. */
#define ZSTD_EDIST_DIAG_MAX ((S32)(1 << 30))

/* How large should a snake be to be considered a 'big' snake.
 * For an explanation of what a 'snake' is with respect to the
 * edit distance matrix, see the linked paper in zstd_edist.h */
#define ZSTD_EDIST_SNAKE_THRESH 20

/* After how many iterations should we start to use the heuristic
 * based on 'big' snakes */
#define ZSTD_EDIST_SNAKE_ITER_THRESH 200

/* After how many iterations should we just give up and take
 * the best available edit script for this round */
#define ZSTD_EDIST_EXPENSIVE_THRESH 1024
45 /*-*************************************
47 ***************************************/
60 S32* forwardDiag; /* Entries of the forward diagonal stored here */
61 S32* backwardDiag; /* Entries of the backward diagonal stored here.
62 * Note: this buffer and the 'forwardDiag' buffer
63 * are contiguous. See the ZSTD_eDist_genSequences */
64 ZSTD_eDist_match* matches; /* Accumulate matches of length 1 in this buffer.
65 * In a subsequence post-processing step, we combine
66 * contiguous matches. */
71 S32 dictMid; /* The mid diagonal for the dictionary */
72 S32 srcMid; /* The mid diagonal for the source */
73 int lowUseHeuristics; /* Should we use heuristics for the low part */
74 int highUseHeuristics; /* Should we use heuristics for the high part */
75 } ZSTD_eDist_partition;
77 /*-*************************************
79 ***************************************/
81 static void ZSTD_eDist_diag(ZSTD_eDist_state* state,
82 ZSTD_eDist_partition* partition,
83 S32 dictLow, S32 dictHigh, S32 srcLow,
84 S32 srcHigh, int useHeuristics)
86 S32* const forwardDiag = state->forwardDiag;
87 S32* const backwardDiag = state->backwardDiag;
88 const BYTE* const dict = state->dict;
89 const BYTE* const src = state->src;
91 S32 const diagMin = dictLow - srcHigh;
92 S32 const diagMax = dictHigh - srcLow;
93 S32 const forwardMid = dictLow - srcLow;
94 S32 const backwardMid = dictHigh - srcHigh;
96 S32 forwardMin = forwardMid;
97 S32 forwardMax = forwardMid;
98 S32 backwardMin = backwardMid;
99 S32 backwardMax = backwardMid;
100 int odd = (forwardMid - backwardMid) & 1;
103 forwardDiag[forwardMid] = dictLow;
104 backwardDiag[backwardMid] = dictHigh;
106 /* Main loop for updating diag entries. Unless useHeuristics is
107 * set to false, this loop will run until it finds the minimal
109 for (iterations = 1;;iterations++) {
113 if (forwardMin > diagMin) {
115 forwardDiag[forwardMin - 1] = -1;
120 if (forwardMax < diagMax) {
122 forwardDiag[forwardMax + 1] = -1;
127 for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
130 S32 low = forwardDiag[diag - 1];
131 S32 high = forwardDiag[diag + 1];
132 S32 dictIdx0 = low < high ? high : low + 1;
134 for (dictIdx = dictIdx0, srcIdx = dictIdx0 - diag;
135 dictIdx < dictHigh && srcIdx < srcHigh && dict[dictIdx] == src[srcIdx];
136 dictIdx++, srcIdx++) continue;
138 if (dictIdx - dictIdx0 > ZSTD_EDIST_SNAKE_THRESH)
141 forwardDiag[diag] = dictIdx;
143 if (odd && backwardMin <= diag && diag <= backwardMax && backwardDiag[diag] <= dictIdx) {
144 partition->dictMid = dictIdx;
145 partition->srcMid = srcIdx;
146 partition->lowUseHeuristics = 0;
147 partition->highUseHeuristics = 0;
152 if (backwardMin > diagMin) {
154 backwardDiag[backwardMin - 1] = ZSTD_EDIST_DIAG_MAX;
159 if (backwardMax < diagMax) {
161 backwardDiag[backwardMax + 1] = ZSTD_EDIST_DIAG_MAX;
167 for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
170 S32 low = backwardDiag[diag - 1];
171 S32 high = backwardDiag[diag + 1];
172 S32 dictIdx0 = low < high ? low : high - 1;
174 for (dictIdx = dictIdx0, srcIdx = dictIdx0 - diag;
175 dictLow < dictIdx && srcLow < srcIdx && dict[dictIdx - 1] == src[srcIdx - 1];
176 dictIdx--, srcIdx--) continue;
178 if (dictIdx0 - dictIdx > ZSTD_EDIST_SNAKE_THRESH)
181 backwardDiag[diag] = dictIdx;
183 if (!odd && forwardMin <= diag && diag <= forwardMax && dictIdx <= forwardDiag[diag]) {
184 partition->dictMid = dictIdx;
185 partition->srcMid = srcIdx;
186 partition->lowUseHeuristics = 0;
187 partition->highUseHeuristics = 0;
195 /* Everything under this point is a heuristic. Using these will
196 * substantially speed up the match finding. In some cases, taking
197 * the total match finding time from several minutes to seconds.
198 * Of course, the caveat is that the edit script found may no longer
201 /* Big snake heuristic */
202 if (iterations > ZSTD_EDIST_SNAKE_ITER_THRESH && bigSnake) {
206 for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
207 S32 diagDiag = diag - forwardMid;
208 S32 dictIdx = forwardDiag[diag];
209 S32 srcIdx = dictIdx - diag;
210 S32 v = (dictIdx - dictLow) * 2 - diagDiag;
212 if (v > 12 * (iterations + (diagDiag < 0 ? -diagDiag : diagDiag))) {
214 && dictLow + ZSTD_EDIST_SNAKE_THRESH <= dictIdx && dictIdx <= dictHigh
215 && srcLow + ZSTD_EDIST_SNAKE_THRESH <= srcIdx && srcIdx <= srcHigh) {
217 for (k = 1; dict[dictIdx - k] == src[srcIdx - k]; k++) {
218 if (k == ZSTD_EDIST_SNAKE_THRESH) {
220 partition->dictMid = dictIdx;
221 partition->srcMid = srcIdx;
230 partition->lowUseHeuristics = 0;
231 partition->highUseHeuristics = 1;
239 for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
240 S32 diagDiag = diag - backwardMid;
241 S32 dictIdx = backwardDiag[diag];
242 S32 srcIdx = dictIdx - diag;
243 S32 v = (dictHigh - dictIdx) * 2 + diagDiag;
245 if (v > 12 * (iterations + (diagDiag < 0 ? -diagDiag : diagDiag))) {
247 && dictLow < dictIdx && dictIdx <= dictHigh - ZSTD_EDIST_SNAKE_THRESH
248 && srcLow < srcIdx && srcIdx <= srcHigh - ZSTD_EDIST_SNAKE_THRESH) {
250 for (k = 0; dict[dictIdx + k] == src[srcIdx + k]; k++) {
251 if (k == ZSTD_EDIST_SNAKE_THRESH - 1) {
253 partition->dictMid = dictIdx;
254 partition->srcMid = srcIdx;
263 partition->lowUseHeuristics = 1;
264 partition->highUseHeuristics = 0;
270 /* More general 'too expensive' heuristic */
271 if (iterations >= ZSTD_EDIST_EXPENSIVE_THRESH) {
272 S32 forwardDictSrcBest;
273 S32 forwardDictBest = 0;
274 S32 backwardDictSrcBest;
275 S32 backwardDictBest = 0;
277 forwardDictSrcBest = -1;
278 for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
279 S32 dictIdx = MIN(forwardDiag[diag], dictHigh);
280 S32 srcIdx = dictIdx - diag;
282 if (srcHigh < srcIdx) {
283 dictIdx = srcHigh + diag;
287 if (forwardDictSrcBest < dictIdx + srcIdx) {
288 forwardDictSrcBest = dictIdx + srcIdx;
289 forwardDictBest = dictIdx;
293 backwardDictSrcBest = ZSTD_EDIST_DIAG_MAX;
294 for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
295 S32 dictIdx = MAX(dictLow, backwardDiag[diag]);
296 S32 srcIdx = dictIdx - diag;
298 if (srcIdx < srcLow) {
299 dictIdx = srcLow + diag;
303 if (dictIdx + srcIdx < backwardDictSrcBest) {
304 backwardDictSrcBest = dictIdx + srcIdx;
305 backwardDictBest = dictIdx;
309 if ((dictHigh + srcHigh) - backwardDictSrcBest < forwardDictSrcBest - (dictLow + srcLow)) {
310 partition->dictMid = forwardDictBest;
311 partition->srcMid = forwardDictSrcBest - forwardDictBest;
312 partition->lowUseHeuristics = 0;
313 partition->highUseHeuristics = 1;
315 partition->dictMid = backwardDictBest;
316 partition->srcMid = backwardDictSrcBest - backwardDictBest;
317 partition->lowUseHeuristics = 1;
318 partition->highUseHeuristics = 0;
325 static void ZSTD_eDist_insertMatch(ZSTD_eDist_state* state,
326 S32 const dictIdx, S32 const srcIdx)
328 state->matches[state->nbMatches].dictIdx = dictIdx;
329 state->matches[state->nbMatches].srcIdx = srcIdx;
330 state->matches[state->nbMatches].matchLength = 1;
334 static int ZSTD_eDist_compare(ZSTD_eDist_state* state,
335 S32 dictLow, S32 dictHigh, S32 srcLow,
336 S32 srcHigh, int useHeuristics)
338 const BYTE* const dict = state->dict;
339 const BYTE* const src = state->src;
341 /* Found matches while traversing from the low end */
342 while (dictLow < dictHigh && srcLow < srcHigh && dict[dictLow] == src[srcLow]) {
343 ZSTD_eDist_insertMatch(state, dictLow, srcLow);
348 /* Found matches while traversing from the high end */
349 while (dictLow < dictHigh && srcLow < srcHigh && dict[dictHigh - 1] == src[srcHigh - 1]) {
350 ZSTD_eDist_insertMatch(state, dictHigh - 1, srcHigh - 1);
355 /* If the low and high end end up touching. If we wanted to make
356 * note of the differences like most diffing algorithms do, we would
357 * do so here. In our case, we're only concerned with matches
358 * Note: if you wanted to find the edit distance of the algorithm,
359 * you could just accumulate the cost for an insertion/deletion
361 if (dictLow == dictHigh) {
362 while (srcLow < srcHigh) {
363 /* Reaching this point means inserting src[srcLow] into
364 * the current position of dict */
367 } else if (srcLow == srcHigh) {
368 while (dictLow < dictHigh) {
369 /* Reaching this point means deleting dict[dictLow] from
370 * the current position of dict */
374 ZSTD_eDist_partition partition;
375 partition.dictMid = 0;
376 partition.srcMid = 0;
377 ZSTD_eDist_diag(state, &partition, dictLow, dictHigh,
378 srcLow, srcHigh, useHeuristics);
379 if (ZSTD_eDist_compare(state, dictLow, partition.dictMid,
380 srcLow, partition.srcMid, partition.lowUseHeuristics))
382 if (ZSTD_eDist_compare(state, partition.dictMid, dictHigh,
383 partition.srcMid, srcHigh, partition.highUseHeuristics))
390 static int ZSTD_eDist_matchComp(const void* p, const void* q)
392 S32 const l = ((ZSTD_eDist_match*)p)->srcIdx;
393 S32 const r = ((ZSTD_eDist_match*)q)->srcIdx;
397 /* The matches from the approach above will all be of the form
398 * (dictIdx, srcIdx, 1). This method combines contiguous matches
399 * of length MINMATCH or greater. Matches less than MINMATCH
401 static void ZSTD_eDist_combineMatches(ZSTD_eDist_state* state)
403 /* Create a new buffer to put the combined matches into
404 * and memcpy to state->matches after */
405 ZSTD_eDist_match* combinedMatches =
406 ZSTD_malloc(state->nbMatches * sizeof(ZSTD_eDist_match),
409 U32 nbCombinedMatches = 1;
412 /* Make sure that the srcIdx and dictIdx are in sorted order.
413 * The combination step won't work otherwise */
414 qsort(state->matches, state->nbMatches, sizeof(ZSTD_eDist_match), ZSTD_eDist_matchComp);
416 memcpy(combinedMatches, state->matches, sizeof(ZSTD_eDist_match));
417 for (i = 1; i < state->nbMatches; i++) {
418 ZSTD_eDist_match const match = state->matches[i];
419 ZSTD_eDist_match const combinedMatch =
420 combinedMatches[nbCombinedMatches - 1];
421 if (combinedMatch.srcIdx + combinedMatch.matchLength == match.srcIdx &&
422 combinedMatch.dictIdx + combinedMatch.matchLength == match.dictIdx) {
423 combinedMatches[nbCombinedMatches - 1].matchLength++;
425 /* Discard matches that are less than MINMATCH */
426 if (combinedMatches[nbCombinedMatches - 1].matchLength < MINMATCH) {
430 memcpy(combinedMatches + nbCombinedMatches,
431 state->matches + i, sizeof(ZSTD_eDist_match));
435 memcpy(state->matches, combinedMatches, nbCombinedMatches * sizeof(ZSTD_eDist_match));
436 state->nbMatches = nbCombinedMatches;
437 ZSTD_free(combinedMatches, ZSTD_defaultCMem);
440 static size_t ZSTD_eDist_convertMatchesToSequences(ZSTD_Sequence* sequences,
441 ZSTD_eDist_state* state)
443 const ZSTD_eDist_match* matches = state->matches;
444 size_t const nbMatches = state->nbMatches;
445 size_t const dictSize = state->dictSize;
446 size_t nbSequences = 0;
448 for (i = 0; i < nbMatches; i++) {
449 ZSTD_eDist_match const match = matches[i];
450 U32 const litLength = !i ? match.srcIdx :
451 match.srcIdx - (matches[i - 1].srcIdx + matches[i - 1].matchLength);
452 U32 const offset = (match.srcIdx + dictSize) - match.dictIdx;
453 U32 const matchLength = match.matchLength;
454 sequences[nbSequences].offset = offset;
455 sequences[nbSequences].litLength = litLength;
456 sequences[nbSequences].matchLength = matchLength;
462 /*-*************************************
464 ***************************************/
466 static size_t ZSTD_eDist_hamingDist(const BYTE* const a,
467 const BYTE* const b, size_t n)
471 for (i = 0; i < n; i++)
472 dist += a[i] != b[i];
476 /* This is a pretty naive recursive implementation that should only
477 * be used for quick tests obviously. Don't try and run this on a
478 * GB file or something. There are faster implementations. Use those
479 * if you need to run it for large files. */
480 static size_t ZSTD_eDist_levenshteinDist(const BYTE* const s,
481 size_t const sn, const BYTE* const t,
491 if (s[sn - 1] == t[tn - 1])
492 return ZSTD_eDist_levenshteinDist(
493 s, sn - 1, t, tn - 1);
495 a = ZSTD_eDist_levenshteinDist(s, sn - 1, t, tn - 1);
496 b = ZSTD_eDist_levenshteinDist(s, sn, t, tn - 1);
497 c = ZSTD_eDist_levenshteinDist(s, sn - 1, t, tn);
507 static void ZSTD_eDist_validateMatches(ZSTD_eDist_match* matches,
508 size_t const nbMatches, const BYTE* const dict,
509 size_t const dictSize, const BYTE* const src,
510 size_t const srcSize)
513 for (i = 0; i < nbMatches; i++) {
514 ZSTD_eDist_match match = matches[i];
515 U32 const dictIdx = match.dictIdx;
516 U32 const srcIdx = match.srcIdx;
517 U32 const matchLength = match.matchLength;
519 assert(dictIdx + matchLength < dictSize);
520 assert(srcIdx + matchLength < srcSize);
521 assert(!memcmp(dict + dictIdx, src + srcIdx, matchLength));
525 /*-*************************************
527 ***************************************/
529 size_t ZSTD_eDist_genSequences(ZSTD_Sequence* sequences,
530 const void* dict, size_t dictSize,
531 const void* src, size_t srcSize,
534 size_t const nbDiags = dictSize + srcSize + 3;
535 S32* buffer = ZSTD_malloc(nbDiags * 2 * sizeof(S32), ZSTD_defaultCMem);
536 ZSTD_eDist_state state;
537 size_t nbSequences = 0;
539 state.dict = (const BYTE*)dict;
540 state.src = (const BYTE*)src;
541 state.dictSize = dictSize;
542 state.srcSize = srcSize;
543 state.forwardDiag = buffer;
544 state.backwardDiag = buffer + nbDiags;
545 state.forwardDiag += srcSize + 1;
546 state.backwardDiag += srcSize + 1;
547 state.matches = ZSTD_malloc(srcSize * sizeof(ZSTD_eDist_match), ZSTD_defaultCMem);
550 ZSTD_eDist_compare(&state, 0, dictSize, 0, srcSize, 1);
551 ZSTD_eDist_combineMatches(&state);
552 nbSequences = ZSTD_eDist_convertMatchesToSequences(sequences, &state);
554 ZSTD_free(buffer, ZSTD_defaultCMem);
555 ZSTD_free(state.matches, ZSTD_defaultCMem);