reloc patching, avoid int3 overfill
[ia32rtools.git] / tools / cmpmrg_text.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <linux/coff.h>
5 #include <assert.h>
6 #include <stdint.h>
7
8 #include "my_assert.h"
9
10 /* http://www.delorie.com/djgpp/doc/coff/ */
11
/*
 * COFF file header, read verbatim from the start of the file (or from right
 * after the PE signature). It mirrors the on-disk record, so the members use
 * fixed-width types to keep the layout independent of the host's int sizes.
 */
typedef struct {
  uint16_t f_magic;         /* magic number             */
  uint16_t f_nscns;         /* number of sections       */
  uint32_t f_timdat;        /* time & date stamp        */
  uint32_t f_symptr;        /* file pointer to symtab   */
  uint32_t f_nsyms;         /* number of symtab entries */
  uint16_t f_opthdr;        /* sizeof(optional hdr)     */
  uint16_t f_flags;         /* flags                    */
} FILHDR;
21
/*
 * Standard part of the COFF optional (a.out) header. It is read straight
 * from the file, hence the fixed-width member types.
 */
typedef struct {
  uint16_t magic;          /* type of file                         */
  uint16_t vstamp;         /* version stamp                        */
  uint32_t tsize;          /* text size in bytes, padded to FW bdry*/
  uint32_t dsize;          /* initialized data    "  "             */
  uint32_t bsize;          /* uninitialized data  "  "             */
  uint32_t entry;          /* entry pt.                            */
  uint32_t text_start;     /* base of text used for this file      */
  uint32_t data_start;     /* base of data used for this file      */
} AOUTHDR;
32
/*
 * COFF section header. It is read from and written back to the object file
 * as raw bytes, so every member has a fixed width.
 */
typedef struct {
  char     s_name[8];  /* section name                     */
  uint32_t s_paddr;    /* physical address, aliased s_nlib */
  uint32_t s_vaddr;    /* virtual address                  */
  uint32_t s_size;     /* section size                     */
  uint32_t s_scnptr;   /* file ptr to raw data for section */
  uint32_t s_relptr;   /* file ptr to relocation           */
  uint32_t s_lnnoptr;  /* file ptr to line numbers         */
  uint16_t s_nreloc;   /* number of relocation entries     */
  uint16_t s_nlnno;    /* number of line number entries    */
  uint32_t s_flags;    /* flags                            */
} SCNHDR;
45
/*
 * COFF relocation entry. The on-disk record is 10 bytes with no padding,
 * which is why the struct is packed; arrays of it are read and written as
 * raw bytes.
 */
typedef struct {
  uint32_t r_vaddr;   /* address of relocation      */
  uint32_t r_symndx;  /* symbol we're adjusting for */
  uint16_t r_type;    /* type of relocation         */
} __attribute__((packed)) RELOC;
51
52 typedef struct {
53   union {
54     char e_name[E_SYMNMLEN];
55     struct {
56       unsigned int e_zeroes;
57       unsigned int e_offset;
58     } e;
59   } e;
60   unsigned int e_value;
61   short e_scnum;
62   unsigned short e_type;
63   unsigned char e_sclass;
64   unsigned char e_numaux;
65 } __attribute__((packed)) SYMENT;
66
67 #define C_EXT 2
68
/* One collected symbol, as produced by parse_headers(). */
struct my_symtab {
  unsigned int addr;  /* e_value of the symbol table entry              */
  unsigned int fpos;  /* file position of that entry (for patching)     */
  char *name;         /* points into the string table, or a strdup copy */
};
74
75 struct my_sect_info {
76         long scnhdr_fofs;
77         long sect_fofs;
78         long reloc_fofs;
79         uint8_t *data;
80         long size;
81         RELOC *relocs;
82         long reloc_cnt;
83 };
84
85 static int symt_cmp(const void *p1_, const void *p2_)
86 {
87         const struct my_symtab *p1 = p1_, *p2 = p2_;
88         return p1->addr - p2->addr;
89 }
90
91 void parse_headers(FILE *f, unsigned int *base_out,
92         struct my_sect_info *sect_i,
93         struct my_symtab **symtab_out, long *sym_cnt)
94 {
95         struct my_symtab *symt_o = NULL;
96         char *stringtab = NULL;
97         unsigned int base = 0;
98         int text_scnum = 0;
99         long filesize;
100         char symname[9];
101         long opthdr_pos;
102         long reloc_size;
103         FILHDR hdr;
104         AOUTHDR opthdr;
105         SCNHDR scnhdr;
106         SYMENT syment;
107         int i, s, val;
108         int ret;
109         
110         ret = fseek(f, 0, SEEK_END);
111         my_assert(ret, 0);
112
113         filesize = ftell(f);
114
115         ret = fseek(f, 0, SEEK_SET);
116         my_assert(ret, 0);
117
118         ret = fread(&hdr, 1, sizeof(hdr), f);
119         my_assert(ret, sizeof(hdr));
120
121         if (hdr.f_magic == 0x5a4d) // MZ
122         {
123                 ret = fseek(f, 0x3c, SEEK_SET);
124                 my_assert(ret, 0);
125                 ret = fread(&val, 1, sizeof(val), f);
126                 my_assert(ret, sizeof(val));
127
128                 ret = fseek(f, val, SEEK_SET);
129                 my_assert(ret, 0);
130                 ret = fread(&val, 1, sizeof(val), f);
131                 my_assert(ret, sizeof(val));
132                 my_assert(val, 0x4550); // PE
133
134                 // should be COFF now
135                 ret = fread(&hdr, 1, sizeof(hdr), f);
136                 my_assert(ret, sizeof(hdr));
137         }
138
139         my_assert(hdr.f_magic, COFF_I386MAGIC);
140
141         if (hdr.f_opthdr != 0)
142         {
143                 opthdr_pos = ftell(f);
144
145                 if (hdr.f_opthdr < sizeof(opthdr))
146                         my_assert(1, 0);
147
148                 ret = fread(&opthdr, 1, sizeof(opthdr), f);
149                 my_assert(ret, sizeof(opthdr));
150                 my_assert(opthdr.magic, COFF_ZMAGIC);
151
152                 //printf("text_start: %x\n", opthdr.text_start);
153
154                 if (hdr.f_opthdr > sizeof(opthdr)) {
155                         ret = fread(&base, 1, sizeof(base), f);
156                         my_assert(ret, sizeof(base));
157                         //printf("base: %x\n", base);
158                 }
159                 ret = fseek(f, opthdr_pos + hdr.f_opthdr, SEEK_SET);
160                 my_assert(ret, 0);
161         }
162
163         // note: assuming first non-empty one is .text ..
164         for (s = 0; s < hdr.f_nscns; s++) {
165                 sect_i->scnhdr_fofs = ftell(f);
166
167                 ret = fread(&scnhdr, 1, sizeof(scnhdr), f);
168                 my_assert(ret, sizeof(scnhdr));
169
170                 if (scnhdr.s_size != 0) {
171                         text_scnum = s + 1;
172                         break;
173                 }
174         }
175         my_assert(s < hdr.f_nscns, 1);
176
177 #if 0
178         printf("f_nsyms:  %x\n", hdr.f_nsyms);
179         printf("s_name:   '%s'\n", scnhdr.s_name);
180         printf("s_vaddr:  %x\n", scnhdr.s_vaddr);
181         printf("s_size:   %x\n", scnhdr.s_size);
182         //printf("s_scnptr: %x\n", scnhdr.s_scnptr);
183         printf("s_nreloc: %x\n", scnhdr.s_nreloc);
184         printf("--\n");
185 #endif
186
187         ret = fseek(f, scnhdr.s_scnptr, SEEK_SET);
188         my_assert(ret, 0);
189
190         sect_i->data = malloc(scnhdr.s_size);
191         my_assert_not(sect_i->data, NULL);
192         ret = fread(sect_i->data, 1, scnhdr.s_size, f);
193         my_assert(ret, scnhdr.s_size);
194
195         sect_i->sect_fofs = scnhdr.s_scnptr;
196         sect_i->size = scnhdr.s_size;
197
198         // relocs
199         ret = fseek(f, scnhdr.s_relptr, SEEK_SET);
200         my_assert(ret, 0);
201
202         reloc_size = scnhdr.s_nreloc * sizeof(sect_i->relocs[0]);
203         sect_i->relocs = malloc(reloc_size + 1);
204         my_assert_not(sect_i->relocs, NULL);
205         ret = fread(sect_i->relocs, 1, reloc_size, f);
206         my_assert(ret, reloc_size);
207
208         sect_i->reloc_cnt = scnhdr.s_nreloc;
209         sect_i->reloc_fofs = scnhdr.s_relptr;
210
211         if (base != 0 && base_out != NULL)
212                 *base_out = base + scnhdr.s_vaddr;
213
214         if (symtab_out == NULL || sym_cnt == NULL)
215                 return;
216
217         // symtab
218         if (hdr.f_nsyms != 0) {
219                 symname[8] = 0;
220
221                 symt_o = malloc(hdr.f_nsyms * sizeof(symt_o[0]) + 1);
222                 my_assert_not(symt_o, NULL);
223
224                 ret = fseek(f, hdr.f_symptr
225                                 + hdr.f_nsyms * sizeof(syment), SEEK_SET);
226                 my_assert(ret, 0);
227                 ret = fread(&i, 1, sizeof(i), f);
228                 my_assert(ret, sizeof(i));
229                 my_assert((unsigned int)i < filesize, 1);
230
231                 stringtab = malloc(i);
232                 my_assert_not(stringtab, NULL);
233                 memset(stringtab, 0, 4);
234                 ret = fread(stringtab + 4, 1, i - 4, f);
235                 my_assert(ret, i - 4);
236
237                 ret = fseek(f, hdr.f_symptr, SEEK_SET);
238                 my_assert(ret, 0);
239         }
240
241         for (i = s = 0; i < hdr.f_nsyms; i++) {
242                 long pos = ftell(f);
243
244                 ret = fread(&syment, 1, sizeof(syment), f);
245                 my_assert(ret, sizeof(syment));
246
247                 strncpy(symname, syment.e.e_name, 8);
248                 //printf("%3d %2d %08x '%s'\n", syment.e_sclass,
249                 //      syment.e_scnum, syment.e_value, symname);
250
251                 if (syment.e_scnum != text_scnum || syment.e_sclass != C_EXT)
252                         continue;
253
254                 symt_o[s].addr = syment.e_value;
255                 symt_o[s].fpos = pos;
256                 if (syment.e.e.e_zeroes == 0)
257                         symt_o[s].name = stringtab + syment.e.e.e_offset;
258                 else
259                         symt_o[s].name = strdup(symname);
260                 s++;
261
262                 if (syment.e_numaux) {
263                         ret = fseek(f, syment.e_numaux * sizeof(syment),
264                                     SEEK_CUR);
265                         my_assert(ret, 0);
266                         i += syment.e_numaux;
267                 }
268         }
269
270         if (symt_o != NULL)
271                 qsort(symt_o, s, sizeof(symt_o[0]), symt_cmp);
272
273         *sym_cnt = s;
274         *symtab_out = symt_o;
275 }
276
/*
 * Reconcile a run of int3 (0xcc) padding in the exe with the filler bytes
 * found at the same place in the obj.
 *
 * The length of the 0xcc run at d_exe (at most maxlen bytes) is measured
 * first. d_obj must then consist of the known filler sequences below,
 * taken greedily in chunks of up to 7 bytes (a 9-byte run is p7 followed
 * by p2). Presumably these are the multi-byte no-op forms the assembler
 * emits for alignment - TODO confirm.
 *
 * Returns 1 if the whole run matched, in which case those bytes of d_obj
 * have been overwritten with 0xcc; returns 0 otherwise, leaving d_obj
 * untouched so that the caller can still inspect and report the original
 * bytes. An empty run counts as a match.
 */
static int handle_pad(uint8_t *d_obj, uint8_t *d_exe, int maxlen)
{
	static const uint8_t p7[7] = { 0x8d, 0xa4, 0x24, 0x00, 0x00, 0x00, 0x00 };
	static const uint8_t p6[6] = { 0x8d, 0x9b, 0x00, 0x00, 0x00, 0x00 };
	static const uint8_t p5[5] = { 0x05, 0x00, 0x00, 0x00, 0x00 };
	static const uint8_t p4[4] = { 0x8d, 0x64, 0x24, 0x00 };
	static const uint8_t p3[3] = { 0x8d, 0x49, 0x00 };
	static const uint8_t p2[2] = { 0x8b, 0xff };
	static const uint8_t p1[1] = { 0x90 };
	// fillers[n - 1] is the n byte filler
	static const uint8_t * const fillers[7] = {
		p1, p2, p3, p4, p5, p6, p7
	};
	int run, pos, chunk;

	for (run = 0; run < maxlen; run++)
		if (d_exe[run] != 0xcc)
			break;

	// verify everything before modifying anything
	for (pos = 0; pos < run; pos += chunk)
	{
		chunk = run - pos;
		if (chunk > 7)
			chunk = 7;

		if (memcmp(d_obj + pos, fillers[chunk - 1], chunk))
			return 0;
	}

	memset(d_obj, 0xcc, run);

	return 1;
}
325
/*
 * One pair of byte patterns that check_equiv() accepts as equivalent
 * encodings.
 *
 * len         - number of pattern bytes
 * ofs         - where the pattern starts relative to the first differing
 *               byte (0 or negative)
 * cmp_rm      - if non-zero, variable bits under a 0xc0 mask byte are
 *               compared with their two 3-bit fields swapped between the
 *               two sides
 * v_masm      - pattern matched against the obj data
 * v_masm_mask - bits set here must equal v_masm; the cleared bits are
 *               variable and get compared against the other side
 * v_msvc,
 * v_msvc_mask - the same for the exe data
 */
struct equiv_opcode {
	signed char len;
	signed char ofs;
	short cmp_rm;
	uint8_t v_masm[8];
	uint8_t v_masm_mask[8];
	uint8_t v_msvc[8];
	uint8_t v_msvc_mask[8];
};

struct equiv_opcode equiv_ops[] = {
	// cmp    $0x11,%ax
	{
		.len = 4, .ofs = -1, .cmp_rm = 0,
		.v_masm      = { 0x66,0x83,0xf8,0x03 },
		.v_masm_mask = { 0xff,0xff,0xff,0x00 },
		.v_msvc      = { 0x66,0x3d,0x03,0x00 },
		.v_msvc_mask = { 0xff,0xff,0x00,0xff },
	},
	// lea    -0x1(%ebx,%eax,1),%esi // op mod/rm sib offs
	// mov, test, imm grp 1
	{
		.len = 3, .ofs = -2, .cmp_rm = 1,
		.v_masm      = { 0x8d,0x74,0x03 },
		.v_masm_mask = { 0xf0,0x07,0xc0 },
		.v_msvc      = { 0x8d,0x74,0x18 },
		.v_msvc_mask = { 0xf0,0x07,0xc0 },
	},
	// movzbl 0x58f24a(%eax,%ecx,1),%eax
	{
		.len = 4, .ofs = -3, .cmp_rm = 1,
		.v_masm      = { 0x0f,0xb6,0x84,0x08 },
		.v_masm_mask = { 0xff,0xff,0x07,0xc0 },
		.v_msvc      = { 0x0f,0xb6,0x84,0x01 },
		.v_msvc_mask = { 0xff,0xff,0x07,0xc0 },
	},
	// inc/dec
	{
		.len = 3, .ofs = -2, .cmp_rm = 1,
		.v_masm      = { 0xfe,0x4c,0x03 },
		.v_masm_mask = { 0xfe,0xff,0xc0 },
		.v_msvc      = { 0xfe,0x4c,0x18 },
		.v_msvc_mask = { 0xfe,0xff,0xc0 },
	},
	// cmp
	{
		.len = 3, .ofs = -2, .cmp_rm = 1,
		.v_masm      = { 0x38,0x0c,0x0c },
		.v_masm_mask = { 0xff,0xff,0xc0 },
		.v_msvc      = { 0x38,0x0c,0x30 },
		.v_msvc_mask = { 0xff,0xff,0xc0 },
	},
	// test   %dl,%bl
	{
		.len = 2, .ofs = -1, .cmp_rm = 1,
		.v_masm      = { 0x84,0xd3 },
		.v_masm_mask = { 0xfe,0xc0 },
		.v_msvc      = { 0x84,0xda },
		.v_msvc_mask = { 0xfe,0xc0 },
	},
	// cmp    r,r/m vs rm/r
	{
		.len = 2, .ofs = 0, .cmp_rm = 1,
		.v_masm      = { 0x3a,0xca },
		.v_masm_mask = { 0xff,0xc0 },
		.v_msvc      = { 0x38,0xd1 },
		.v_msvc_mask = { 0xff,0xc0 },
	},
	// rep + 66 prefix
	{
		.len = 2, .ofs = 0, .cmp_rm = 0,
		.v_masm      = { 0xf3,0x66 },
		.v_masm_mask = { 0xfe,0xff },
		.v_msvc      = { 0x66,0xf3 },
		.v_msvc_mask = { 0xff,0xfe },
	},
	// fadd   st, st(0) vs st(0), st
	{
		.len = 2, .ofs = 0, .cmp_rm = 0,
		.v_masm      = { 0xd8,0xc0 },
		.v_masm_mask = { 0xff,0xf7 },
		.v_msvc      = { 0xdc,0xc0 },
		.v_msvc_mask = { 0xff,0xf7 },
	},

	// broad filters (may take too much..)
	// testb  $0x4,0x1d(%esi,%eax,1)
	// movb, push, ..
	{
		.len = 3, .ofs = -2, .cmp_rm = 1,
		.v_masm      = { 0xf6,0x44,0x06 },
		.v_masm_mask = { 0x00,0x07,0xc0 },
		.v_msvc      = { 0xf6,0x44,0x30 },
		.v_msvc_mask = { 0x00,0x07,0xc0 },
	},
};
380
/*
 * Masked comparison of 'len' bytes: only the bits that are set in mask[]
 * take part. Returns 0 when d[] agrees with expect[] on all of those bits
 * (also when len is not positive), 1 at the first byte that doesn't.
 */
static int cmp_mask(uint8_t *d, uint8_t *expect, uint8_t *mask, int len)
{
	int pos = 0;

	while (pos < len) {
		// bits that differ and are not masked out
		if ((d[pos] ^ expect[pos]) & mask[pos])
			return 1;
		pos++;
	}

	return 0;
}
391
392 static int check_equiv(uint8_t *d_obj, uint8_t *d_exe, int maxlen)
393 {
394         uint8_t vo, ve, vo2, ve2;
395         int i, jo, je;
396         int len, ofs;
397
398         for (i = 0; i < sizeof(equiv_ops) / sizeof(equiv_ops[0]); i++)
399         {
400                 struct equiv_opcode *op = &equiv_ops[i];
401
402                 len = op->len;
403                 if (maxlen < len)
404                         continue;
405
406                 ofs = op->ofs;
407                 if (cmp_mask(d_obj + ofs, op->v_masm,
408                              op->v_masm_mask, len))
409                         continue;
410                 if (cmp_mask(d_exe + ofs, op->v_msvc,
411                              op->v_msvc_mask, len))
412                         continue;
413
414                 jo = je = 0;
415                 d_obj += ofs;
416                 d_exe += ofs;
417                 while (1)
418                 {
419                         for (; jo < len; jo++)
420                                 if (op->v_masm_mask[jo] != 0xff)
421                                         break;
422                         for (; je < len; je++)
423                                 if (op->v_msvc_mask[je] != 0xff)
424                                         break;
425
426                         if ((jo == len && je != len) || (jo != len && je == len)) {
427                                 printf("invalid equiv_ops\n");
428                                 return -1;
429                         }
430                         if (jo == len)
431                                 return len + ofs - 1; // matched
432
433                         // var byte
434                         vo = d_obj[jo] & ~op->v_masm_mask[jo];
435                         ve = d_exe[je] & ~op->v_msvc_mask[je];
436                         if (op->cmp_rm && op->v_masm_mask[jo] == 0xc0) {
437                                 vo2 = vo >> 3;
438                                 vo &= 7;
439                                 ve2 = ve & 7;
440                                 ve >>= 3;
441                                 if (vo != ve || vo2 != ve2)
442                                         return -1;
443                         }
444                         else {
445                                 if (vo != ve)
446                                         return -1;
447                         }
448
449                         jo++;
450                         je++;
451                 }
452         }
453
454         return -1;
455 }
456
/*
 * Overwrite up to 'len' bytes at 'd' with int3 (0xcc), stopping early at
 * the first byte that already is 0xcc so that existing padding and
 * whatever follows it is left alone.
 */
static void fill_int3(unsigned char *d, int len)
{
	int k;

	for (k = 0; k < len; k++) {
		if (d[k] == 0xcc)
			return;
		d[k] = 0xcc;
	}
}
465
466 int main(int argc, char *argv[])
467 {
468         struct my_sect_info s_text_obj, s_text_exe;
469         struct my_symtab *syms_obj = NULL;
470         long sym_cnt_obj;
471         FILE *f_obj, *f_exe;
472         unsigned int base = 0, addr, end;
473         SCNHDR tmphdr;
474         long sztext_cmn;
475         int retval = 1;
476         int left;
477         int ret;
478         int i;
479
480         if (argc != 3) {
481                 printf("usage:\n%s <a_obj> <exe>\n", argv[0]);
482                 return 1;
483         }
484
485         f_obj = fopen(argv[1], "r+b");
486         if (f_obj == NULL) {
487                 fprintf(stderr, "%s", argv[1]);
488                 perror("");
489                 return 1;
490         }
491
492         f_exe = fopen(argv[2], "r");
493         if (f_exe == NULL) {
494                 fprintf(stderr, "%s", argv[2]);
495                 perror("");
496                 return 1;
497         }
498
499         parse_headers(f_obj, NULL, &s_text_obj, &syms_obj, &sym_cnt_obj);
500         parse_headers(f_exe, &base, &s_text_exe, NULL, NULL);
501
502         sztext_cmn = s_text_obj.size;
503         if (sztext_cmn > s_text_exe.size)
504                 sztext_cmn = s_text_exe.size;
505
506         if (sztext_cmn == 0) {
507                 printf("bad .text size(s): %ld, %ld\n",
508                         s_text_obj.size, s_text_exe.size);
509                 return 1;
510         }
511
512         for (i = 0; i < s_text_obj.reloc_cnt; i++)
513         {
514                 unsigned int a = s_text_obj.relocs[i].r_vaddr;
515                 //printf("%04x %08x\n", s_text_obj.relocs[i].r_type, a);
516
517                 switch (s_text_obj.relocs[i].r_type) {
518                 case 0x06: // RELOC_ADDR32
519                 case 0x14: // RELOC_REL32
520                         // must preserve stored val,
521                         // so trash exe so that cmp passes
522                         memcpy(s_text_exe.data + a, s_text_obj.data + a, 4);
523                         break;
524                 default:
525                         printf("unknown reloc %x @%08x/%08x\n",
526                                 s_text_obj.relocs[i].r_type, a, base + a);
527                         return 1;
528                 }
529         }
530
531         for (i = 0; i < sztext_cmn; i++)
532         {
533                 if (s_text_obj.data[i] == s_text_exe.data[i])
534                         continue;
535
536                 left = sztext_cmn - i;
537
538                 if (s_text_exe.data[i] == 0xcc) { // padding
539                         if (handle_pad(s_text_obj.data + i,
540                             s_text_exe.data + i, left))
541                                 continue;
542                 }
543
544                 ret = check_equiv(s_text_obj.data + i, s_text_exe.data + i, left);
545                 if (ret >= 0) {
546                         i += ret;
547                         continue;
548                 }
549
550                 printf("%x: %02x vs %02x\n", base + i,
551                         s_text_obj.data[i], s_text_exe.data[i]);
552                 goto out;
553         }
554
555         // fill removed funcs with 'int3'
556         for (i = 0; i < sym_cnt_obj; i++) {
557                 if (strncmp(syms_obj[i].name, "rm_", 3))
558                         continue;
559
560                 addr = syms_obj[i].addr;
561                 end = (i < sym_cnt_obj - 1)
562                         ? syms_obj[i + 1].addr : s_text_obj.size;
563                 if (addr >= s_text_obj.size || end > s_text_obj.size) {
564                         printf("addr OOR: %x-%x '%s'\n", addr, end,
565                                 syms_obj[i].name);
566                         goto out;
567                 }
568                 fill_int3(s_text_obj.data + addr, end - addr);
569         }
570
571         // remove relocs
572         for (i = 0; i < s_text_obj.reloc_cnt; i++) {
573                 addr = s_text_obj.relocs[i].r_vaddr;
574                 if (addr > s_text_obj.size - 4) {
575                         printf("reloc addr OOR: %x\n", addr);
576                         goto out;
577                 }
578                 if (*(unsigned int *)(s_text_obj.data + addr) == 0xcccccccc) {
579                         memmove(&s_text_obj.relocs[i],
580                                 &s_text_obj.relocs[i + 1],
581                                 (s_text_obj.reloc_cnt - i - 1)
582                                  * sizeof(s_text_obj.relocs[0]));
583                         i--;
584                         s_text_obj.reloc_cnt--;
585                 }
586         }
587
588         // patch .text
589         ret = fseek(f_obj, s_text_obj.sect_fofs, SEEK_SET);
590         my_assert(ret, 0);
591         ret = fwrite(s_text_obj.data, 1, s_text_obj.size, f_obj);
592         my_assert(ret, s_text_obj.size);
593
594         // patch relocs
595         ret = fseek(f_obj, s_text_obj.reloc_fofs, SEEK_SET);
596         my_assert(ret, 0);
597         ret = fwrite(s_text_obj.relocs, sizeof(s_text_obj.relocs[0]),
598                 s_text_obj.reloc_cnt, f_obj);
599         my_assert(ret, s_text_obj.reloc_cnt);
600
601         ret = fseek(f_obj, s_text_obj.scnhdr_fofs, SEEK_SET);
602         my_assert(ret, 0);
603         ret = fread(&tmphdr, 1, sizeof(tmphdr), f_obj);
604         my_assert(ret, sizeof(tmphdr));
605
606         tmphdr.s_nreloc = s_text_obj.reloc_cnt;
607
608         ret = fseek(f_obj, s_text_obj.scnhdr_fofs, SEEK_SET);
609         my_assert(ret, 0);
610         ret = fwrite(&tmphdr, 1, sizeof(tmphdr), f_obj);
611         my_assert(ret, sizeof(tmphdr));
612
613         fclose(f_obj);
614         fclose(f_exe);
615
616         retval = 0;
617 out:
618         return retval;
619 }