git subrepo pull (merge) --force deps/libchdr
[pcsx_rearmed.git] / deps / libchdr / deps / zstd-1.5.5 / contrib / match_finders / zstd_edist.c
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10
11 /*-*************************************
12 *  Dependencies
13 ***************************************/
14
15 /* Currently relies on qsort when combining contiguous matches. This can probably 
16  * be avoided but would require changes to the algorithm. The qsort is far from 
17  * the bottleneck in this algorithm even for medium sized files so it's probably 
18  * not worth trying to address */ 
19 #include <stdlib.h>
20 #include <assert.h>
21
22 #include "zstd_edist.h"
23 #include "mem.h"
24
25 /*-*************************************
26 *  Constants
27 ***************************************/
28
29 /* Just a sential for the entries of the diagonal matrix */
30 #define ZSTD_EDIST_DIAG_MAX (S32)(1 << 30)
31
32 /* How large should a snake be to be considered a 'big' snake. 
33  * For an explanation of what a 'snake' is with respect to the 
34  * edit distance matrix, see the linked paper in zstd_edist.h */
35 #define ZSTD_EDIST_SNAKE_THRESH 20
36
37 /* After how many iterations should we start to use the heuristic
38  * based on 'big' snakes */
39 #define ZSTD_EDIST_SNAKE_ITER_THRESH 200
40
41 /* After how many iterations should be just give up and take 
42  * the best available edit script for this round */ 
43 #define ZSTD_EDIST_EXPENSIVE_THRESH 1024
44
45 /*-*************************************
46 *  Structures
47 ***************************************/
48
49 typedef struct {
50     U32 dictIdx;
51     U32 srcIdx;
52     U32 matchLength;
53 } ZSTD_eDist_match;
54
55 typedef struct {
56     const BYTE* dict;
57     const BYTE* src;
58     size_t dictSize;
59     size_t srcSize;
60     S32* forwardDiag;            /* Entries of the forward diagonal stored here */
61     S32* backwardDiag;           /* Entries of the backward diagonal stored here.
62                                   *   Note: this buffer and the 'forwardDiag' buffer 
63                                   *   are contiguous. See the ZSTD_eDist_genSequences */
64     ZSTD_eDist_match* matches;   /* Accumulate matches of length 1 in this buffer. 
65                                   *   In a subsequence post-processing step, we combine 
66                                   *   contiguous matches. */
67     U32 nbMatches;
68 } ZSTD_eDist_state;
69
70 typedef struct {
71     S32 dictMid;           /* The mid diagonal for the dictionary */
72     S32 srcMid;            /* The mid diagonal for the source */ 
73     int lowUseHeuristics;  /* Should we use heuristics for the low part */
74     int highUseHeuristics; /* Should we use heuristics for the high part */ 
75 } ZSTD_eDist_partition;
76
77 /*-*************************************
78 *  Internal
79 ***************************************/
80
81 static void ZSTD_eDist_diag(ZSTD_eDist_state* state,
82                     ZSTD_eDist_partition* partition,
83                     S32 dictLow, S32 dictHigh, S32 srcLow, 
84                     S32 srcHigh, int useHeuristics)
85 {
86     S32* const forwardDiag = state->forwardDiag;
87     S32* const backwardDiag = state->backwardDiag;
88     const BYTE* const dict = state->dict;
89     const BYTE* const src = state->src;
90
91     S32 const diagMin = dictLow - srcHigh;
92     S32 const diagMax = dictHigh - srcLow;
93     S32 const forwardMid = dictLow - srcLow;
94     S32 const backwardMid = dictHigh - srcHigh;
95
96     S32 forwardMin = forwardMid;
97     S32 forwardMax = forwardMid;
98     S32 backwardMin = backwardMid;
99     S32 backwardMax = backwardMid;
100     int odd = (forwardMid - backwardMid) & 1;
101     U32 iterations;
102
103     forwardDiag[forwardMid] = dictLow;
104     backwardDiag[backwardMid] = dictHigh;
105
106     /* Main loop for updating diag entries. Unless useHeuristics is 
107      * set to false, this loop will run until it finds the minimal 
108      * edit script */ 
109     for (iterations = 1;;iterations++) {
110         S32 diag;
111         int bigSnake = 0;
112         
113         if (forwardMin > diagMin) {
114             forwardMin--;
115             forwardDiag[forwardMin - 1] = -1;
116         } else {
117             forwardMin++;
118         }
119
120         if (forwardMax < diagMax) {
121             forwardMax++;
122             forwardDiag[forwardMax + 1] = -1;
123         } else {
124             forwardMax--;
125         }
126
127         for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
128             S32 dictIdx;
129             S32 srcIdx;
130             S32 low = forwardDiag[diag - 1];
131             S32 high = forwardDiag[diag + 1];
132             S32 dictIdx0 = low < high ? high : low + 1;
133
134             for (dictIdx = dictIdx0, srcIdx = dictIdx0 - diag;
135                 dictIdx < dictHigh && srcIdx < srcHigh && dict[dictIdx] == src[srcIdx];
136                 dictIdx++, srcIdx++) continue;
137
138             if (dictIdx - dictIdx0 > ZSTD_EDIST_SNAKE_THRESH)
139                 bigSnake = 1;
140
141             forwardDiag[diag] = dictIdx;
142
143             if (odd && backwardMin <= diag && diag <= backwardMax && backwardDiag[diag] <= dictIdx) {
144                 partition->dictMid = dictIdx;
145                 partition->srcMid = srcIdx;
146                 partition->lowUseHeuristics = 0;
147                 partition->highUseHeuristics = 0;
148                 return;
149             }
150         }
151
152         if (backwardMin > diagMin) {
153             backwardMin--;
154             backwardDiag[backwardMin - 1] = ZSTD_EDIST_DIAG_MAX;
155         } else {
156             backwardMin++;
157         }
158
159         if (backwardMax < diagMax) {
160             backwardMax++;
161             backwardDiag[backwardMax + 1] = ZSTD_EDIST_DIAG_MAX;
162         } else {
163             backwardMax--;
164         }
165
166
167         for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
168             S32 dictIdx;
169             S32 srcIdx;
170             S32 low = backwardDiag[diag - 1];
171             S32 high = backwardDiag[diag + 1];
172             S32 dictIdx0 = low < high ? low : high - 1;
173
174             for (dictIdx = dictIdx0, srcIdx = dictIdx0 - diag;
175                 dictLow < dictIdx && srcLow < srcIdx && dict[dictIdx - 1] == src[srcIdx - 1];
176                 dictIdx--, srcIdx--) continue;
177
178             if (dictIdx0 - dictIdx > ZSTD_EDIST_SNAKE_THRESH)
179                 bigSnake = 1;
180
181             backwardDiag[diag] = dictIdx;
182
183             if (!odd && forwardMin <= diag && diag <= forwardMax && dictIdx <= forwardDiag[diag]) {
184                 partition->dictMid = dictIdx;
185                 partition->srcMid = srcIdx;
186                 partition->lowUseHeuristics = 0;
187                 partition->highUseHeuristics = 0;
188                 return;
189             }
190         }
191
192         if (!useHeuristics)
193             continue;
194
195         /* Everything under this point is a heuristic. Using these will 
196          * substantially speed up the match finding. In some cases, taking 
197          * the total match finding time from several minutes to seconds.
198          * Of course, the caveat is that the edit script found may no longer 
199          * be optimal */ 
200
201         /* Big snake heuristic */ 
202         if (iterations > ZSTD_EDIST_SNAKE_ITER_THRESH && bigSnake) {
203             {
204                 S32 best = 0;
205                 
206                 for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
207                     S32 diagDiag = diag - forwardMid;
208                     S32 dictIdx = forwardDiag[diag];
209                     S32 srcIdx = dictIdx - diag;
210                     S32 v = (dictIdx - dictLow) * 2 - diagDiag;
211
212                     if (v > 12 * (iterations + (diagDiag < 0 ? -diagDiag : diagDiag))) {
213                         if (v > best 
214                           && dictLow + ZSTD_EDIST_SNAKE_THRESH <= dictIdx && dictIdx <= dictHigh
215                           && srcLow + ZSTD_EDIST_SNAKE_THRESH <= srcIdx && srcIdx <= srcHigh) {
216                             S32 k;
217                             for (k = 1; dict[dictIdx - k] == src[srcIdx - k]; k++) {
218                                 if (k == ZSTD_EDIST_SNAKE_THRESH) {
219                                     best = v;
220                                     partition->dictMid = dictIdx;
221                                     partition->srcMid = srcIdx;
222                                     break;
223                                 }
224                             }
225                         }
226                     }
227                 }
228
229                 if (best > 0) {
230                     partition->lowUseHeuristics = 0;
231                     partition->highUseHeuristics = 1;
232                     return;
233                 }
234             }
235
236             {
237                 S32 best = 0;
238
239                 for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
240                     S32 diagDiag = diag - backwardMid;
241                     S32 dictIdx = backwardDiag[diag];
242                     S32 srcIdx = dictIdx - diag;
243                     S32 v = (dictHigh - dictIdx) * 2 + diagDiag;
244
245                     if (v > 12 * (iterations + (diagDiag < 0 ? -diagDiag : diagDiag))) {
246                         if (v > best 
247                           && dictLow < dictIdx && dictIdx <= dictHigh - ZSTD_EDIST_SNAKE_THRESH
248                           && srcLow < srcIdx && srcIdx <= srcHigh - ZSTD_EDIST_SNAKE_THRESH) {
249                             int k;
250                             for (k = 0; dict[dictIdx + k] == src[srcIdx + k]; k++) {
251                                 if (k == ZSTD_EDIST_SNAKE_THRESH - 1) { 
252                                     best = v;
253                                     partition->dictMid = dictIdx;
254                                     partition->srcMid = srcIdx;
255                                     break; 
256                                 }
257                             }
258                         }
259                     }
260                 }
261
262                 if (best > 0) {
263                     partition->lowUseHeuristics = 1;
264                     partition->highUseHeuristics = 0;
265                     return;
266                 }
267             }
268         }
269
270         /* More general 'too expensive' heuristic */ 
271         if (iterations >= ZSTD_EDIST_EXPENSIVE_THRESH) {
272             S32 forwardDictSrcBest;
273             S32 forwardDictBest = 0;
274             S32 backwardDictSrcBest;
275             S32 backwardDictBest = 0;
276
277             forwardDictSrcBest = -1;
278             for (diag = forwardMax; diag >= forwardMin; diag -= 2) {
279                 S32 dictIdx = MIN(forwardDiag[diag], dictHigh);
280                 S32 srcIdx = dictIdx - diag;
281
282                 if (srcHigh < srcIdx) {
283                     dictIdx = srcHigh + diag;
284                     srcIdx = srcHigh;
285                 }
286
287                 if (forwardDictSrcBest < dictIdx + srcIdx) {
288                     forwardDictSrcBest = dictIdx + srcIdx;
289                     forwardDictBest = dictIdx;
290                 }
291             }
292
293             backwardDictSrcBest = ZSTD_EDIST_DIAG_MAX;
294             for (diag = backwardMax; diag >= backwardMin; diag -= 2) {
295                 S32 dictIdx = MAX(dictLow, backwardDiag[diag]);
296                 S32 srcIdx = dictIdx - diag;
297
298                 if (srcIdx < srcLow) {
299                     dictIdx = srcLow + diag;
300                     srcIdx = srcLow;
301                 }
302
303                 if (dictIdx + srcIdx < backwardDictSrcBest) {
304                     backwardDictSrcBest = dictIdx + srcIdx;
305                     backwardDictBest = dictIdx;
306                 }
307             }
308
309             if ((dictHigh + srcHigh) - backwardDictSrcBest < forwardDictSrcBest - (dictLow + srcLow)) {
310                 partition->dictMid = forwardDictBest;
311                 partition->srcMid = forwardDictSrcBest - forwardDictBest;
312                 partition->lowUseHeuristics = 0;
313                 partition->highUseHeuristics = 1;
314             } else {
315                 partition->dictMid = backwardDictBest;
316                 partition->srcMid = backwardDictSrcBest - backwardDictBest;
317                 partition->lowUseHeuristics = 1;
318                 partition->highUseHeuristics = 0;
319             }
320             return;
321         }
322     }
323 }
324
325 static void ZSTD_eDist_insertMatch(ZSTD_eDist_state* state, 
326                     S32 const dictIdx, S32 const srcIdx)
327 {
328     state->matches[state->nbMatches].dictIdx = dictIdx;
329     state->matches[state->nbMatches].srcIdx = srcIdx;
330     state->matches[state->nbMatches].matchLength = 1;
331     state->nbMatches++;
332 }
333
334 static int ZSTD_eDist_compare(ZSTD_eDist_state* state,
335                     S32 dictLow, S32 dictHigh, S32 srcLow,
336                     S32 srcHigh, int useHeuristics)
337 {
338     const BYTE* const dict = state->dict;
339     const BYTE* const src = state->src;
340
341     /* Found matches while traversing from the low end */ 
342     while (dictLow < dictHigh && srcLow < srcHigh && dict[dictLow] == src[srcLow]) {
343         ZSTD_eDist_insertMatch(state, dictLow, srcLow);
344         dictLow++;
345         srcLow++;
346     }
347
348     /* Found matches while traversing from the high end */
349     while (dictLow < dictHigh && srcLow < srcHigh && dict[dictHigh - 1] == src[srcHigh - 1]) {
350         ZSTD_eDist_insertMatch(state, dictHigh - 1, srcHigh - 1);
351         dictHigh--;
352         srcHigh--;
353     }
354     
355     /* If the low and high end end up touching. If we wanted to make 
356      * note of the differences like most diffing algorithms do, we would 
357      * do so here. In our case, we're only concerned with matches 
358      * Note: if you wanted to find the edit distance of the algorithm, 
359      *   you could just accumulate the cost for an insertion/deletion 
360      *   below. */ 
361     if (dictLow == dictHigh) {
362         while (srcLow < srcHigh) {
363             /* Reaching this point means inserting src[srcLow] into 
364              * the current position of dict */ 
365             srcLow++;
366         }
367     } else if (srcLow == srcHigh) {
368         while (dictLow < dictHigh) {
369             /* Reaching this point means deleting dict[dictLow] from 
370              * the current position of dict */ 
371             dictLow++;
372         }
373     } else {
374         ZSTD_eDist_partition partition;
375         partition.dictMid = 0;
376         partition.srcMid = 0;
377         ZSTD_eDist_diag(state, &partition, dictLow, dictHigh, 
378             srcLow, srcHigh, useHeuristics);
379         if (ZSTD_eDist_compare(state, dictLow, partition.dictMid, 
380           srcLow, partition.srcMid, partition.lowUseHeuristics))
381             return 1;
382         if (ZSTD_eDist_compare(state, partition.dictMid, dictHigh,
383           partition.srcMid, srcHigh, partition.highUseHeuristics))
384             return 1;
385     }
386
387     return 0;
388 }
389
390 static int ZSTD_eDist_matchComp(const void* p, const void* q)
391 {
392     S32 const l = ((ZSTD_eDist_match*)p)->srcIdx;
393     S32 const r = ((ZSTD_eDist_match*)q)->srcIdx;
394     return (l - r);
395 }
396
397 /* The matches from the approach above will all be of the form 
398  * (dictIdx, srcIdx, 1). This method combines contiguous matches 
399  * of length MINMATCH or greater. Matches less than MINMATCH 
400  * are discarded */ 
401 static void ZSTD_eDist_combineMatches(ZSTD_eDist_state* state)
402 {
403     /* Create a new buffer to put the combined matches into 
404      * and memcpy to state->matches after */ 
405     ZSTD_eDist_match* combinedMatches = 
406         ZSTD_malloc(state->nbMatches * sizeof(ZSTD_eDist_match), 
407         ZSTD_defaultCMem);
408
409     U32 nbCombinedMatches = 1;
410     size_t i;
411
412     /* Make sure that the srcIdx and dictIdx are in sorted order.
413      * The combination step won't work otherwise */ 
414     qsort(state->matches, state->nbMatches, sizeof(ZSTD_eDist_match), ZSTD_eDist_matchComp);
415
416     memcpy(combinedMatches, state->matches, sizeof(ZSTD_eDist_match));
417     for (i = 1; i < state->nbMatches; i++) {
418         ZSTD_eDist_match const match = state->matches[i];
419         ZSTD_eDist_match const combinedMatch = 
420             combinedMatches[nbCombinedMatches - 1];
421         if (combinedMatch.srcIdx + combinedMatch.matchLength == match.srcIdx && 
422           combinedMatch.dictIdx + combinedMatch.matchLength == match.dictIdx) {
423             combinedMatches[nbCombinedMatches - 1].matchLength++;
424         } else {
425             /* Discard matches that are less than MINMATCH */
426             if (combinedMatches[nbCombinedMatches - 1].matchLength < MINMATCH) {
427                 nbCombinedMatches--;
428             }
429
430             memcpy(combinedMatches + nbCombinedMatches, 
431                 state->matches + i, sizeof(ZSTD_eDist_match));
432             nbCombinedMatches++;
433         }
434     }
435     memcpy(state->matches, combinedMatches, nbCombinedMatches * sizeof(ZSTD_eDist_match));
436     state->nbMatches = nbCombinedMatches;
437     ZSTD_free(combinedMatches, ZSTD_defaultCMem);
438 }
439
440 static size_t ZSTD_eDist_convertMatchesToSequences(ZSTD_Sequence* sequences, 
441     ZSTD_eDist_state* state)
442 {
443     const ZSTD_eDist_match* matches = state->matches;
444     size_t const nbMatches = state->nbMatches;
445     size_t const dictSize = state->dictSize;
446     size_t nbSequences = 0;
447     size_t i;
448     for (i = 0; i < nbMatches; i++) {
449         ZSTD_eDist_match const match = matches[i];
450         U32 const litLength = !i ? match.srcIdx : 
451             match.srcIdx - (matches[i - 1].srcIdx + matches[i - 1].matchLength);
452         U32 const offset = (match.srcIdx + dictSize) - match.dictIdx;
453         U32 const matchLength = match.matchLength;
454         sequences[nbSequences].offset = offset;
455         sequences[nbSequences].litLength = litLength;
456         sequences[nbSequences].matchLength = matchLength;
457         nbSequences++;
458     }
459     return nbSequences;
460 }
461
462 /*-*************************************
463 *  Internal utils
464 ***************************************/
465
466 static size_t ZSTD_eDist_hamingDist(const BYTE* const a,
467                         const BYTE* const b, size_t n)
468 {
469     size_t i;
470     size_t dist = 0;
471     for (i = 0; i < n; i++)
472         dist += a[i] != b[i];
473     return dist; 
474 }
475
476 /* This is a pretty naive recursive implementation that should only
477  * be used for quick tests obviously. Don't try and run this on a 
478  * GB file or something. There are faster implementations. Use those
479  * if you need to run it for large files. */
480 static size_t ZSTD_eDist_levenshteinDist(const BYTE* const s,
481                         size_t const sn, const BYTE* const t,
482                         size_t const tn)
483 {
484     size_t a, b, c;
485
486     if (!sn)
487         return tn;
488     if (!tn)
489         return sn;
490     
491     if (s[sn - 1] == t[tn - 1])
492         return ZSTD_eDist_levenshteinDist(
493             s, sn - 1, t, tn - 1);
494     
495     a = ZSTD_eDist_levenshteinDist(s, sn - 1, t, tn - 1);
496     b = ZSTD_eDist_levenshteinDist(s, sn, t, tn - 1);
497     c = ZSTD_eDist_levenshteinDist(s, sn - 1, t, tn);
498
499     if (a > b)
500         a = b;
501     if (a > c)
502         a = c;
503     
504     return a + 1;
505 }
506
507 static void ZSTD_eDist_validateMatches(ZSTD_eDist_match* matches,
508                         size_t const nbMatches, const BYTE* const dict,
509                         size_t const dictSize, const BYTE* const src,
510                         size_t const srcSize)
511 {
512     size_t i;
513     for (i = 0; i < nbMatches; i++) {
514         ZSTD_eDist_match match = matches[i];
515         U32 const dictIdx = match.dictIdx;
516         U32 const srcIdx = match.srcIdx;
517         U32 const matchLength = match.matchLength;
518         
519         assert(dictIdx + matchLength < dictSize);
520         assert(srcIdx + matchLength < srcSize);
521         assert(!memcmp(dict + dictIdx, src + srcIdx, matchLength));
522     }
523 }
524
525 /*-*************************************
526 *  API
527 ***************************************/
528
529 size_t ZSTD_eDist_genSequences(ZSTD_Sequence* sequences, 
530                         const void* dict, size_t dictSize,
531                         const void* src, size_t srcSize,
532                         int useHeuristics)
533 {
534     size_t const nbDiags = dictSize + srcSize + 3;
535     S32* buffer = ZSTD_malloc(nbDiags * 2 * sizeof(S32), ZSTD_defaultCMem);
536     ZSTD_eDist_state state;
537     size_t nbSequences = 0;
538
539     state.dict = (const BYTE*)dict;
540     state.src = (const BYTE*)src;
541     state.dictSize = dictSize;
542     state.srcSize = srcSize;
543     state.forwardDiag = buffer;
544     state.backwardDiag = buffer + nbDiags;
545     state.forwardDiag += srcSize + 1;
546     state.backwardDiag += srcSize + 1;
547     state.matches = ZSTD_malloc(srcSize * sizeof(ZSTD_eDist_match), ZSTD_defaultCMem);
548     state.nbMatches = 0;
549
550     ZSTD_eDist_compare(&state, 0, dictSize, 0, srcSize, 1);
551     ZSTD_eDist_combineMatches(&state);
552     nbSequences = ZSTD_eDist_convertMatchesToSequences(sequences, &state);
553
554     ZSTD_free(buffer, ZSTD_defaultCMem);
555     ZSTD_free(state.matches, ZSTD_defaultCMem);
556
557     return nbSequences;
558 }