2 # ################################################################
3 # Copyright (c) Meta Platforms, Inc. and affiliates.
6 # This source code is licensed under both the BSD-style license (found in the
7 # LICENSE file in the root directory of this source tree) and the GPLv2 (found
8 # in the COPYING file in the root directory of this source tree).
9 # You may select, at your option, one of the above-listed licenses.
10 # ##########################################################################
18 from typing import Optional
# Subdirectories of the source library whose files are copied into the
# freestanding output library (iterated by Freestanding._copy_source_lib).
INCLUDED_SUBDIRS = [
    "common",
    "compress",
    "decompress",
]
30 "common/zstd_trace.h",
31 "compress/zstdmt_compress.h",
32 "compress/zstdmt_compress.c",
41 class FileLines(object):
42 def __init__(self, filename):
43 self.filename = filename
44 with open(self.filename, "r") as f:
45 self.lines = f.readlines()
48 with open(self.filename, "w") as f:
49 f.write("".join(self.lines))
52 class PartialPreprocessor(object):
54 Looks for simple ifdefs and ifndefs and replaces them.
56 Has fancy logic to handle translating elifs to ifs.
57 Only looks for macros in the first part of the expression with no
59 Does not handle multi-line macros (only looks in first line).
61 def __init__(self, defs: [(str, Optional[str])], replaces: [(str, str)], undefs: [str]):
62 MACRO_GROUP = r"(?P<macro>[a-zA-Z_][a-zA-Z_0-9]*)"
63 ELIF_GROUP = r"(?P<elif>el)?"
64 OP_GROUP = r"(?P<op>&&|\|\|)?"
66 self._defs = {macro:value for macro, value in defs}
67 self._replaces = {macro:value for macro, value in replaces}
68 self._defs.update(self._replaces)
69 self._undefs = set(undefs)
71 self._define = re.compile(r"\s*#\s*define")
72 self._if = re.compile(r"\s*#\s*if")
73 self._elif = re.compile(r"\s*#\s*(?P<elif>el)if")
74 self._else = re.compile(r"\s*#\s*(?P<else>else)")
75 self._endif = re.compile(r"\s*#\s*endif")
77 self._ifdef = re.compile(fr"\s*#\s*if(?P<not>n)?def {MACRO_GROUP}\s*")
78 self._if_defined = re.compile(
79 fr"\s*#\s*{ELIF_GROUP}if\s+(?P<not>!)?\s*defined\s*\(\s*{MACRO_GROUP}\s*\)\s*{OP_GROUP}"
81 self._if_defined_value = re.compile(
82 fr"\s*#\s*{ELIF_GROUP}if\s+defined\s*\(\s*{MACRO_GROUP}\s*\)\s*"
85 fr"(?P<macro2>[a-zA-Z_][a-zA-Z_0-9]*)\s*"
86 fr"(?P<cmp>[=><!]+)\s*"
87 fr"(?P<value>[0-9]*)\s*"
88 fr"(?P<closep>\))?\s*"
90 self._if_true = re.compile(
91 fr"\s*#\s*{ELIF_GROUP}if\s+{MACRO_GROUP}\s*{OP_GROUP}"
94 self._c_comment = re.compile(r"/\*.*?\*/")
95 self._cpp_comment = re.compile(r"//")
97 def _log(self, *args, **kwargs):
98 print(*args, **kwargs)
100 def _strip_comments(self, line):
101 # First strip c-style comments (may include //)
103 m = self._c_comment.search(line)
106 line = line[:m.start()] + line[m.end():]
108 # Then strip cpp-style comments
109 m = self._cpp_comment.search(line)
111 line = line[:m.start()]
115 def _fixup_indentation(self, macro, replace: [str]):
116 if len(replace) == 0:
118 if len(replace) == 1 and self._define.match(replace[0]) is None:
119 # If there is only one line, only replace defines
125 if not line.startswith('#'):
128 replace = [line[1:] for line in replace]
130 min_spaces = len(replace[0])
133 for i, c in enumerate(line):
135 # Non-preprocessor line ==> skip the fixup
136 if not all_pound and c != '#':
140 min_spaces = min(min_spaces, spaces)
142 replace = [line[min_spaces:] for line in replace]
145 replace = ["#" + line for line in replace]
149 def _handle_if_block(self, macro, idx, is_true, prepend):
151 Remove the #if or #elif block starting on this line.
162 line = self._inlines[idx]
163 is_if = self._if.match(line) is not None
164 assert is_if or self._elif.match(line) is not None
172 while idx < len(self._inlines):
173 line = self._inlines[idx]
174 # Nested if statement
175 if self._if.match(line):
179 # We're inside a nested statement
181 if self._endif.match(line):
186 # We're at the original depth
188 # Looking only for an endif.
189 # We've found a true statement, but haven't
190 # completely elided the if block, so we just
191 # remove the remainder.
192 if state == REMOVE_REST:
193 if self._endif.match(line):
195 # Remove the endif because we took the first if
202 if state == KEEP_ONE:
203 m = self._elif.match(line)
204 if self._endif.match(line):
205 replace += self._inlines[start_idx + 1:idx]
209 if self._elif.match(line) or self._else.match(line):
210 replace += self._inlines[start_idx + 1:idx]
215 if state == REMOVE_ONE:
216 m = self._elif.match(line)
223 replace.append(line[:b] + line[e:])
226 m = self._else.match(line)
230 while self._endif.match(self._inlines[idx]) is None:
231 replace.append(self._inlines[idx])
236 if self._endif.match(line):
238 # Remove the endif because no other elifs
245 raise RuntimeError("Unterminated if block!")
247 replace = self._fixup_indentation(macro, replace)
249 self._log(f"\tHardwiring {macro}")
251 self._log(f"\t\t {self._inlines[start_idx - 1][:-1]}")
252 for x in range(start_idx, idx):
253 self._log(f"\t\t- {self._inlines[x][:-1]}")
255 self._log(f"\t\t+ {line[:-1]}")
256 if idx < len(self._inlines):
257 self._log(f"\t\t {self._inlines[idx][:-1]}")
261 def _preprocess_once(self):
265 while idx < len(self._inlines):
266 line = self._inlines[idx]
267 sline = self._strip_comments(line)
268 m = self._ifdef.fullmatch(sline)
271 m = self._if_defined_value.fullmatch(sline)
273 m = self._if_defined.match(sline)
275 m = self._if_true.match(sline)
276 if_true = (m is not None)
278 outlines.append(line)
282 groups = m.groupdict()
283 macro = groups['macro']
284 op = groups.get('op')
286 if not (macro in self._defs or macro in self._undefs):
287 outlines.append(line)
291 defined = macro in self._defs
293 # Needed variables set:
294 # resolved: Is the statement fully resolved?
295 # is_true: If resolved, is the statement true?
299 outlines.append(line)
303 defined_value = self._defs[macro]
306 defined_value = int(defined_value)
313 is_true = (defined_value != 0)
315 if resolved and op is not None:
317 resolved = not is_true
323 ifdef = groups.get('not') is None
324 elseif = groups.get('elif') is not None
326 macro2 = groups.get('macro2')
327 cmp = groups.get('cmp')
328 value = groups.get('value')
329 openp = groups.get('openp')
330 closep = groups.get('closep')
332 is_true = (ifdef == defined)
336 resolved = not is_true
341 if macro2 is not None and not resolved:
342 assert ifdef and defined and op == '&&' and cmp is not None
343 # If the statement is true, but we have a single value check, then
345 defined_value = self._defs[macro]
348 defined_value = int(defined_value)
356 ((openp is None) == (closep is None)) and
361 is_true = defined_value < value
363 is_true = defined_value <= value
365 is_true = defined_value == value
367 is_true = defined_value != value
369 is_true = defined_value >= value
371 is_true = defined_value > value
375 if op is not None and not resolved:
376 # Remove the first op in the line + spaces
382 needle = re.compile(fr"(?P<if>\s*#\s*(el)?if\s+).*?(?P<op>{opre}\s*)")
383 match = needle.match(line)
384 assert match is not None
385 newline = line[:match.end('if')] + line[match.end('op'):]
387 self._log(f"\tHardwiring partially resolved {macro}")
388 self._log(f"\t\t- {line[:-1]}")
389 self._log(f"\t\t+ {newline[:-1]}")
391 outlines.append(newline)
395 # Skip any statements we cannot fully compute
397 outlines.append(line)
402 if macro in self._replaces:
405 value = self._replaces.pop(macro)
406 prepend = [f"#define {macro} {value}\n"]
408 idx, replace = self._handle_if_block(macro, idx, is_true, prepend)
412 return changed, outlines
414 def preprocess(self, filename):
415 with open(filename, 'r') as f:
416 self._inlines = f.readlines()
421 changed, outlines = self._preprocess_once()
422 self._inlines = outlines
424 with open(filename, 'w') as f:
425 f.write(''.join(self._inlines))
428 class Freestanding(object):
430 self, zstd_deps: str, mem: str, source_lib: str, output_lib: str,
431 external_xxhash: bool, xxh64_state: Optional[str],
432 xxh64_prefix: Optional[str], rewritten_includes: [(str, str)],
433 defs: [(str, Optional[str])], replaces: [(str, str)],
434 undefs: [str], excludes: [str], seds: [str], spdx: bool,
436 self._zstd_deps = zstd_deps
438 self._src_lib = source_lib
439 self._dst_lib = output_lib
440 self._external_xxhash = external_xxhash
441 self._xxh64_state = xxh64_state
442 self._xxh64_prefix = xxh64_prefix
443 self._rewritten_includes = rewritten_includes
445 self._replaces = replaces
446 self._undefs = undefs
447 self._excludes = excludes
451 def _dst_lib_file_paths(self):
453 Yields all the file paths in the dst_lib.
455 for root, dirname, filenames in os.walk(self._dst_lib):
456 for filename in filenames:
457 filepath = os.path.join(root, filename)
460 def _log(self, *args, **kwargs):
461 print(*args, **kwargs)
463 def _copy_file(self, lib_path):
464 suffixes = [".c", ".h", ".S"]
465 if not any((lib_path.endswith(suffix) for suffix in suffixes)):
467 if lib_path in SKIPPED_FILES:
468 self._log(f"\tSkipping file: {lib_path}")
470 if self._external_xxhash and lib_path in XXHASH_FILES:
471 self._log(f"\tSkipping xxhash file: {lib_path}")
474 src_path = os.path.join(self._src_lib, lib_path)
475 dst_path = os.path.join(self._dst_lib, lib_path)
476 self._log(f"\tCopying: {src_path} -> {dst_path}")
477 shutil.copyfile(src_path, dst_path)
479 def _copy_source_lib(self):
480 self._log("Copying source library into output library")
482 assert os.path.exists(self._src_lib)
483 os.makedirs(self._dst_lib, exist_ok=True)
484 self._copy_file("zstd.h")
485 self._copy_file("zstd_errors.h")
486 for subdir in INCLUDED_SUBDIRS:
487 src_dir = os.path.join(self._src_lib, subdir)
488 dst_dir = os.path.join(self._dst_lib, subdir)
490 assert os.path.exists(src_dir)
491 os.makedirs(dst_dir, exist_ok=True)
493 for filename in os.listdir(src_dir):
494 lib_path = os.path.join(subdir, filename)
495 self._copy_file(lib_path)
497 def _copy_zstd_deps(self):
498 dst_zstd_deps = os.path.join(self._dst_lib, "common", "zstd_deps.h")
499 self._log(f"Copying zstd_deps: {self._zstd_deps} -> {dst_zstd_deps}")
500 shutil.copyfile(self._zstd_deps, dst_zstd_deps)
503 dst_mem = os.path.join(self._dst_lib, "common", "mem.h")
504 self._log(f"Copying mem: {self._mem} -> {dst_mem}")
505 shutil.copyfile(self._mem, dst_mem)
507 def _hardwire_preprocessor(self, name: str, value: Optional[str] = None, undef=False):
509 If value=None then hardwire that it is defined, but not what the value is.
510 If undef=True then value must be None.
511 If value='' then the macro is defined to '' exactly.
513 assert not (undef and value is not None)
514 for filepath in self._dst_lib_file_paths():
515 file = FileLines(filepath)
def _hardwire_defines(self):
    """Partially preprocess every file in the output library.

    The configured defines, replacements and undefines are passed to a
    PartialPreprocessor, which is then run over each output file.
    """
    self._log("Hardwiring macros")
    # One preprocessor instance is shared by all files. It pops entries from
    # its replacement table once a replacement has been emitted, so each
    # replacement appears to be inserted at most once across the whole
    # library.
    hardwire = PartialPreprocessor(self._defs, self._replaces, self._undefs).preprocess
    for path in self._dst_lib_file_paths():
        hardwire(path)
523 def _remove_excludes(self):
524 self._log("Removing excluded sections")
525 for exclude in self._excludes:
526 self._log(f"\tRemoving excluded sections for: {exclude}")
527 begin_re = re.compile(f"BEGIN {exclude}")
528 end_re = re.compile(f"END {exclude}")
529 for filepath in self._dst_lib_file_paths():
530 file = FileLines(filepath)
534 for line in file.lines:
535 if emit and begin_re.search(line) is not None:
536 assert end_re.search(line) is None
539 outlines.append(line)
542 if end_re.search(line) is not None:
543 assert begin_re.search(line) is None
544 self._log(f"\t\tRemoving excluded section: {exclude}")
546 self._log(f"\t\t\t- {s}")
550 raise RuntimeError("Excluded section unfinished!")
551 file.lines = outlines
554 def _rewrite_include(self, original, rewritten):
555 self._log(f"\tRewriting include: {original} -> {rewritten}")
556 regex = re.compile(f"\\s*#\\s*include\\s*(?P<include>{original})")
557 for filepath in self._dst_lib_file_paths():
558 file = FileLines(filepath)
559 for i, line in enumerate(file.lines):
560 match = regex.match(line)
563 s = match.start('include')
564 e = match.end('include')
565 file.lines[i] = line[:s] + rewritten + line[e:]
568 def _rewrite_includes(self):
569 self._log("Rewriting includes")
570 for original, rewritten in self._rewritten_includes:
571 self._rewrite_include(original, rewritten)
573 def _replace_xxh64_prefix(self):
574 if self._xxh64_prefix is None:
576 self._log(f"Replacing XXH64 prefix with {self._xxh64_prefix}")
578 if self._xxh64_state is not None:
580 (re.compile(r"([^\w]|^)(?P<orig>XXH64_state_t)([^\w]|$)"), self._xxh64_state)
582 if self._xxh64_prefix is not None:
584 (re.compile(r"([^\w]|^)(?P<orig>XXH64)[\(_]"), self._xxh64_prefix)
586 for filepath in self._dst_lib_file_paths():
587 file = FileLines(filepath)
588 for i, line in enumerate(file.lines):
590 for regex, replacement in replacements:
591 match = regex.search(line)
592 while match is not None:
594 b = match.start('orig')
595 e = match.end('orig')
596 line = line[:b] + replacement + line[e:]
597 match = regex.search(line)
599 self._log(f"\t- {file.lines[i][:-1]}")
600 self._log(f"\t+ {line[:-1]}")
604 def _parse_sed(self, sed):
607 match = re.fullmatch(f's{delim}(.+){delim}(.*){delim}(.*)', sed)
608 assert match is not None
609 regex = re.compile(match.group(1))
610 format_str = match.group(2)
611 is_global = match.group(3) == 'g'
612 return regex, format_str, is_global
614 def _process_sed(self, sed):
615 self._log(f"Processing sed: {sed}")
616 regex, format_str, is_global = self._parse_sed(sed)
618 for filepath in self._dst_lib_file_paths():
619 file = FileLines(filepath)
620 for i, line in enumerate(file.lines):
623 match = regex.search(line)
626 replacement = format_str.format(match.groups(''), match.groupdict(''))
629 line = line[:b] + replacement + line[e:]
634 self._log(f"\t- {file.lines[i][:-1]}")
635 self._log(f"\t+ {line[:-1]}")
639 def _process_seds(self):
640 self._log("Processing seds")
641 for sed in self._seds:
642 self._process_sed(sed)
644 def _process_spdx(self):
647 self._log("Processing spdx")
648 SPDX_C = "// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause\n"
649 SPDX_H_S = "/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */\n"
650 for filepath in self._dst_lib_file_paths():
651 file = FileLines(filepath)
652 if file.lines[0] == SPDX_C or file.lines[0] == SPDX_H_S:
654 for line in file.lines:
655 if "SPDX-License-Identifier" in line:
656 raise RuntimeError(f"Unexpected SPDX license identifier: {file.filename} {repr(line)}")
657 if file.filename.endswith(".c"):
658 file.lines.insert(0, SPDX_C)
659 elif file.filename.endswith(".h") or file.filename.endswith(".S"):
660 file.lines.insert(0, SPDX_H_S)
662 raise RuntimeError(f"Unexpected file extension: {file.filename}")
668 self._copy_source_lib()
669 self._copy_zstd_deps()
671 self._hardwire_defines()
672 self._remove_excludes()
673 self._rewrite_includes()
674 self._replace_xxh64_prefix()
679 def parse_optional_pair(defines: [str]) -> [(str, Optional[str])]:
681 for define in defines:
682 parsed = define.split('=')
684 output.append((parsed[0], None))
685 elif len(parsed) == 2:
686 output.append((parsed[0], parsed[1]))
688 raise RuntimeError(f"Bad define: {define}")
692 def parse_pair(rewritten_includes: [str]) -> [(str, str)]:
694 for rewritten_include in rewritten_includes:
695 parsed = rewritten_include.split('=')
697 output.append((parsed[0], parsed[1]))
699 raise RuntimeError(f"Bad rewritten include: {rewritten_include}")
704 def main(name, args):
705 parser = argparse.ArgumentParser(prog=name)
706 parser.add_argument("--zstd-deps", default="zstd_deps.h", help="Zstd dependencies file")
707 parser.add_argument("--mem", default="mem.h", help="Memory module")
708 parser.add_argument("--source-lib", default="../../lib", help="Location of the zstd library")
709 parser.add_argument("--output-lib", default="./freestanding_lib", help="Where to output the freestanding zstd library")
710 parser.add_argument("--xxhash", default=None, help="Alternate external xxhash include e.g. --xxhash='<xxhash.h>'. If set xxhash is not included.")
711 parser.add_argument("--xxh64-state", default=None, help="Alternate XXH64 state type (excluding _) e.g. --xxh64-state='struct xxh64_state'")
712 parser.add_argument("--xxh64-prefix", default=None, help="Alternate XXH64 function prefix (excluding _) e.g. --xxh64-prefix=xxh64")
713 parser.add_argument("--rewrite-include", default=[], dest="rewritten_includes", action="append", help="Rewrite an include REGEX=NEW (e.g. '<stddef\\.h>=<linux/types.h>')")
714 parser.add_argument("--sed", default=[], dest="seds", action="append", help="Apply a sed replacement. Format: `s/REGEX/FORMAT/[g]`. REGEX is a Python regex. FORMAT is a Python format string formatted by the regex dict.")
715 parser.add_argument("--spdx", action="store_true", help="Add SPDX License Identifiers")
716 parser.add_argument("-D", "--define", default=[], dest="defs", action="append", help="Pre-define this macro (can be passed multiple times)")
717 parser.add_argument("-U", "--undefine", default=[], dest="undefs", action="append", help="Pre-undefine this macro (can be passed multiple times)")
718 parser.add_argument("-R", "--replace", default=[], dest="replaces", action="append", help="Pre-define this macro and replace the first ifndef block with its definition")
719 parser.add_argument("-E", "--exclude", default=[], dest="excludes", action="append", help="Exclude all lines between 'BEGIN <EXCLUDE>' and 'END <EXCLUDE>'")
720 args = parser.parse_args(args)
722 # Always remove threading
723 if "ZSTD_MULTITHREAD" not in args.undefs:
724 args.undefs.append("ZSTD_MULTITHREAD")
726 args.defs = parse_optional_pair(args.defs)
727 for name, _ in args.defs:
728 if name in args.undefs:
729 raise RuntimeError(f"{name} is both defined and undefined!")
731 # Always set tracing to 0
732 if "ZSTD_NO_TRACE" not in (arg[0] for arg in args.defs):
733 args.defs.append(("ZSTD_NO_TRACE", None))
734 args.defs.append(("ZSTD_TRACE", "0"))
736 args.replaces = parse_pair(args.replaces)
737 for name, _ in args.replaces:
738 if name in args.undefs or name in args.defs:
739 raise RuntimeError(f"{name} is both replaced and (un)defined!")
741 args.rewritten_includes = parse_pair(args.rewritten_includes)
743 external_xxhash = False
744 if args.xxhash is not None:
745 external_xxhash = True
746 args.rewritten_includes.append(('"(\\.\\./common/)?xxhash.h"', args.xxhash))
748 if args.xxh64_prefix is not None:
749 if not external_xxhash:
750 raise RuntimeError("--xxh64-prefix may only be used with --xxhash provided")
752 if args.xxh64_state is not None:
753 if not external_xxhash:
754 raise RuntimeError("--xxh64-state may only be used with --xxhash provided")
764 args.rewritten_includes,
if __name__ == "__main__":
    # Pass the program name separately; main() uses it as argparse's ``prog``.
    main(name=sys.argv[0], args=sys.argv[1:])