translate: detect more invalid writes to args
[ia32rtools.git] / tools / translate.c
index ecd972f..66f2a50 100644 (file)
@@ -1,3 +1,11 @@
+/*
+ * ia32rtools
+ * (C) notaz, 2013,2014
+ *
+ * This work is licensed under the terms of 3-clause BSD license.
+ * See COPYING file in the top-level directory.
+ */
+
 #define _GNU_SOURCE
 #include <stdio.h>
 #include <stdlib.h>
@@ -46,6 +54,7 @@ enum op_flags {
   OPF_ATAIL  = (1 << 14), /* tail call with reused arg frame */
   OPF_32BIT  = (1 << 15), /* 32bit division */
   OPF_LOCK   = (1 << 16), /* op has lock prefix */
+  OPF_VAPUSH = (1 << 17), /* vararg ptr push (as call arg) */
 };
 
 enum op_op {
@@ -141,11 +150,12 @@ struct parsed_op {
   unsigned char pfo;
   unsigned char pfo_inv;
   unsigned char operand_cnt;
-  unsigned char pad;
+  unsigned char p_argnum; // push: altered before call arg #
+  unsigned char p_argpass;// push: arg of host func
+  unsigned char pad[3];
   int regmask_src;        // all referensed regs
   int regmask_dst;
   int pfomask;            // flagop: parsed_flag_op that can't be delayed
-  int argnum;             // push: altered before call arg #
   int cc_scratch;         // scratch storage during analysis
   int bt_i;               // branch target for branches
   struct parsed_data *btj;// branch targets for jumptables
@@ -464,10 +474,10 @@ static int guess_lmod_from_c_type(enum opr_lenmod *lmod,
   static const char *dword_types[] = {
     "int", "_DWORD", "UINT_PTR", "DWORD",
     "WPARAM", "LPARAM", "UINT", "__int32",
-    "LONG", "HIMC",
+    "LONG", "HIMC", "BOOL",
   };
   static const char *word_types[] = {
-    "uint16_t", "int16_t", "_WORD",
+    "uint16_t", "int16_t", "_WORD", "WORD",
     "unsigned __int16", "__int16",
   };
   static const char *byte_types[] = {
@@ -1361,7 +1371,7 @@ static void parse_stack_access(struct parsed_op *po,
     *bp_arg_out = bp_arg;
 }
 
-static void stack_frame_access(struct parsed_op *po,
+static int stack_frame_access(struct parsed_op *po,
   struct parsed_opr *popr, char *buf, size_t buf_size,
   const char *name, const char *cast, int is_src, int is_lea)
 {
@@ -1373,6 +1383,7 @@ static void stack_frame_access(struct parsed_op *po,
   int unaligned = 0;
   int stack_ra = 0;
   int offset = 0;
+  int retval = -1;
   int sf_ofs;
   int lim;
 
@@ -1394,7 +1405,7 @@ static void stack_frame_access(struct parsed_op *po,
         if (cast[0] == 0)
           cast = "(u32)";
         snprintf(buf, buf_size, "%sap", cast);
-        return;
+        return -1;
       }
       ferr(po, "offset %d (%s,%d) doesn't map to any arg\n",
         offset, bp_arg, arg_i);
@@ -1413,6 +1424,7 @@ static void stack_frame_access(struct parsed_op *po,
       ferr(po, "arg %d not in prototype?\n", arg_i);
 
     popr->is_ptr = g_func_pp->arg[i].type.is_ptr;
+    retval = i;
 
     switch (popr->lmod)
     {
@@ -1484,7 +1496,8 @@ static void stack_frame_access(struct parsed_op *po,
     // common problem
     guess_lmod_from_c_type(&tmp_lmod, &g_func_pp->arg[i].type);
     if (tmp_lmod != OPLM_DWORD
-      && (unaligned || (!is_src && tmp_lmod < popr->lmod)))
+      && (unaligned || (!is_src && lmod_bytes(po, tmp_lmod)
+                         < lmod_bytes(po, popr->lmod) + (offset & 3))))
     {
       ferr(po, "bp_arg arg%d/w offset %d and type '%s' is too small\n",
         i + 1, offset, g_func_pp->arg[i].type.name);
@@ -1547,12 +1560,16 @@ static void stack_frame_access(struct parsed_op *po,
       ferr(po, "bp_stack bad lmod: %d\n", popr->lmod);
     }
   }
+
+  return retval;
 }
 
 static void check_func_pp(struct parsed_op *po,
   const struct parsed_proto *pp, const char *pfx)
 {
+  enum opr_lenmod tmp_lmod;
   char buf[256];
+  int ret, i;
 
   if (pp->argc_reg != 0) {
     if (/*!g_allow_regfunc &&*/ !pp->is_fastcall) {
@@ -1563,6 +1580,18 @@ static void check_func_pp(struct parsed_op *po,
       ferr(po, "%s: %d reg arg(s) with %d stack arg(s)\n",
         pfx, pp->argc_reg, pp->argc_stack);
   }
+
+  // fptrs must use 32bit args, callsite might have no information and
+  // lack a cast to smaller types, which results in incorrectly masked
+  // args passed (callee may assume masked args, it does on ARM)
+  if (!pp->is_oslib) {
+    for (i = 0; i < pp->argc; i++) {
+      ret = guess_lmod_from_c_type(&tmp_lmod, &pp->arg[i].type);
+      if (ret && tmp_lmod != OPLM_DWORD)
+        ferr(po, "reference to %s with arg%d '%s'\n", pp->name,
+          i + 1, pp->arg[i].type.name);
+    }
+  }
 }
 
 static const char *check_label_read_ref(struct parsed_op *po,
@@ -1988,7 +2017,7 @@ static int scan_for_pop(int i, int opcnt, const char *reg,
     }
 
     if ((po->flags & OPF_RMD)
-        || (po->op == OP_PUSH && po->argnum != 0)) // arg push
+        || (po->op == OP_PUSH && po->p_argnum != 0)) // arg push
       continue;
 
     if ((po->flags & OPF_JMP) && po->op != OP_CALL) {
@@ -2574,8 +2603,10 @@ static const struct parsed_proto *resolve_icall(int i, int opcnt,
   return pp;
 }
 
-static int try_resolve_const(int i, const struct parsed_opr *opr,
-  int magic, unsigned int *val)
+// find an instruction that changed opr before i op
+// *op_i must be set to -1
+static int resolve_origin(int i, const struct parsed_opr *opr,
+  int magic, int *op_i)
 {
   struct label_ref *lr;
   int ret = 0;
@@ -2586,7 +2617,7 @@ static int try_resolve_const(int i, const struct parsed_opr *opr,
     if (g_labels[i][0] != 0) {
       lr = &g_label_refs[i];
       for (; lr != NULL; lr = lr->next)
-        ret |= try_resolve_const(lr->i, opr, magic, val);
+        ret |= resolve_origin(lr->i, opr, magic, op_i);
       if (i > 0 && LAST_OP(i - 1))
         return ret;
     }
@@ -2603,12 +2634,36 @@ static int try_resolve_const(int i, const struct parsed_opr *opr,
       continue;
     if (!is_opr_modified(opr, &ops[i]))
       continue;
+
+    if (*op_i >= 0) {
+      if (*op_i == i)
+        return 1;
+      // XXX: could check if the other op does the same
+      return -1;
+    }
+
+    *op_i = i;
+    return 1;
+  }
+}
+
+static int try_resolve_const(int i, const struct parsed_opr *opr,
+  int magic, unsigned int *val)
+{
+  int s_i = -1;
+  int ret = 0;
+
+  ret = resolve_origin(i, opr, magic, &s_i);
+  if (ret == 1) {
+    i = s_i;
     if (ops[i].op != OP_MOV && ops[i].operand[1].type != OPT_CONST)
       return -1;
 
     *val = ops[i].operand[1].val;
     return 1;
   }
+
+  return -1;
 }
 
 static int collect_call_args_r(struct parsed_op *po, int i,
@@ -2618,8 +2673,11 @@ static int collect_call_args_r(struct parsed_op *po, int i,
   struct parsed_proto *pp_tmp;
   struct label_ref *lr;
   int need_to_save_current;
+  int save_args;
   int ret = 0;
-  int j;
+  int reg;
+  char buf[32];
+  int j, k;
 
   if (i < 0) {
     ferr(po, "dead label encountered\n");
@@ -2712,6 +2770,11 @@ static int collect_call_args_r(struct parsed_op *po, int i,
 
       pp->arg[arg].datap = &ops[j];
       need_to_save_current = 0;
+      save_args = 0;
+      reg = -1;
+      if (ops[j].operand[0].type == OPT_REG)
+        reg = ops[j].operand[0].reg;
+
       if (!need_op_saving) {
         ret = scan_for_mod(&ops[j], j + 1, i, 1);
         need_to_save_current = (ret >= 0);
@@ -2719,15 +2782,15 @@ static int collect_call_args_r(struct parsed_op *po, int i,
       if (need_op_saving || need_to_save_current) {
         // mark this push as one that needs operand saving
         ops[j].flags &= ~OPF_RMD;
-        if (ops[j].argnum == 0) {
-          ops[j].argnum = arg + 1;
-          *save_arg_vars |= 1 << arg;
+        if (ops[j].p_argnum == 0) {
+          ops[j].p_argnum = arg + 1;
+          save_args |= 1 << arg;
         }
-        else if (ops[j].argnum < arg + 1)
-          ferr(&ops[j], "argnum conflict (%d<%d) for '%s'\n",
-            ops[j].argnum, arg + 1, pp->name);
+        else if (ops[j].p_argnum < arg + 1)
+          ferr(&ops[j], "p_argnum conflict (%d<%d) for '%s'\n",
+            ops[j].p_argnum, arg + 1, pp->name);
       }
-      else if (ops[j].argnum == 0)
+      else if (ops[j].p_argnum == 0)
         ops[j].flags |= OPF_RMD;
 
       // some PUSHes are reused by different calls on other branches,
@@ -2738,6 +2801,49 @@ static int collect_call_args_r(struct parsed_op *po, int i,
 
       ops[j].flags &= ~OPF_RSAVE;
 
+      // check for __VALIST
+      if (!pp->is_unresolved && pp->arg[arg].type.is_va_list) {
+        k = -1;
+        ret = resolve_origin(j, &ops[j].operand[0], magic + 1, &k);
+        if (ret == 1 && k >= 0)
+        {
+          if (ops[k].op == OP_LEA) {
+            snprintf(buf, sizeof(buf), "arg_%X",
+              g_func_pp->argc_stack * 4);
+            if (!g_func_pp->is_vararg
+              || strstr(ops[k].operand[1].name, buf))
+            {
+              ops[k].flags |= OPF_RMD;
+              ops[j].flags |= OPF_RMD | OPF_VAPUSH;
+              save_args &= ~(1 << arg);
+              reg = -1;
+            }
+            else
+              ferr(&ops[j], "lea va_list used, but no vararg?\n");
+          }
+          // check for va_list from g_func_pp arg too
+          else if (ops[k].op == OP_MOV
+            && is_stack_access(&ops[k], &ops[k].operand[1]))
+          {
+            ret = stack_frame_access(&ops[k], &ops[k].operand[1],
+              buf, sizeof(buf), ops[k].operand[1].name, "", 1, 0);
+            if (ret >= 0) {
+              ops[k].flags |= OPF_RMD;
+              ops[j].flags |= OPF_RMD;
+              ops[j].p_argpass = ret + 1;
+              save_args &= ~(1 << arg);
+              reg = -1;
+            }
+          }
+        }
+      }
+
+      *save_arg_vars |= save_args;
+
+      // tracking reg usage
+      if (reg >= 0)
+        *regmask |= 1 << reg;
+
       arg++;
       if (!pp->is_unresolved) {
         // next arg
@@ -2746,10 +2852,6 @@ static int collect_call_args_r(struct parsed_op *po, int i,
             break;
       }
       magic = (magic & 0xffffff) | (arg << 24);
-
-      // tracking reg usage
-      if (ops[j].operand[0].type == OPT_REG)
-        *regmask |= 1 << ops[j].operand[0].reg;
     }
   }
 
@@ -3172,6 +3274,8 @@ tailcall:
         // indirect call
         pp_c = resolve_icall(i, opcnt, &l);
         if (pp_c != NULL) {
+          if (!pp_c->is_func && !pp_c->is_fptr)
+            ferr(po, "call to non-func: %s\n", pp_c->name);
           pp = proto_clone(pp_c);
           my_assert_not(pp, NULL);
           if (l)
@@ -3287,7 +3391,7 @@ tailcall:
         regmask_save |= 1 << reg;
     }
 
-    if (po->op == OP_PUSH && po->argnum == 0
+    if (po->op == OP_PUSH && po->p_argnum == 0
       && !(po->flags & OPF_RSAVE) && !g_func_pp->is_userstack)
     {
       if (po->operand[0].type == OPT_REG)
@@ -3395,8 +3499,12 @@ tailcall:
             pfomask = 1 << po->pfo;
           }
 
-          if (tmp_op->op == OP_ADD && po->pfo == PFO_C)
-            need_tmp64 = 1;
+          if (tmp_op->op == OP_ADD && po->pfo == PFO_C) {
+            propagate_lmod(tmp_op, &tmp_op->operand[0],
+              &tmp_op->operand[1]);
+            if (tmp_op->operand[0].lmod == OPLM_DWORD)
+              need_tmp64 = 1;
+          }
         }
         if (pfomask) {
           tmp_op->pfomask |= pfomask;
@@ -3417,7 +3525,8 @@ tailcall:
     else if (po->op == OP_MUL
       || (po->op == OP_IMUL && po->operand_cnt == 1))
     {
-      need_tmp64 = 1;
+      if (po->operand[0].lmod == OPLM_DWORD)
+        need_tmp64 = 1;
     }
     else if (po->op == OP_CALL) {
       pp = po->pp;
@@ -3438,7 +3547,7 @@ tailcall:
           tmp_op = pp->arg[arg].datap;
           if (tmp_op == NULL)
             ferr(po, "parsed_op missing for arg%d\n", arg);
-          if (tmp_op->argnum == 0 && tmp_op->operand[0].type == OPT_REG)
+          if (tmp_op->p_argnum == 0 && tmp_op->operand[0].type == OPT_REG)
             regmask_stack |= 1 << tmp_op->operand[0].reg;
         }
 
@@ -4203,13 +4312,22 @@ tailcall:
         assert_operand_cnt(2);
         propagate_lmod(po, &po->operand[0], &po->operand[1]);
         if (pfomask & (1 << PFO_C)) {
-          fprintf(fout, "  tmp64 = (u64)%s + %s;\n",
-            out_src_opr_u32(buf1, sizeof(buf1), po, &po->operand[0]),
-            out_src_opr_u32(buf2, sizeof(buf2), po, &po->operand[1]));
-          fprintf(fout, "  cond_c = tmp64 >> 32;\n");
-          fprintf(fout, "  %s = (u32)tmp64;",
-            out_dst_opr(buf1, sizeof(buf1), po, &po->operand[0]));
-          strcat(g_comment, "add64");
+          out_src_opr_u32(buf1, sizeof(buf1), po, &po->operand[0]);
+          out_src_opr_u32(buf2, sizeof(buf2), po, &po->operand[1]);
+          if (po->operand[0].lmod == OPLM_DWORD) {
+            fprintf(fout, "  tmp64 = (u64)%s + %s;\n", buf1, buf2);
+            fprintf(fout, "  cond_c = tmp64 >> 32;\n");
+            fprintf(fout, "  %s = (u32)tmp64;",
+              out_dst_opr(buf1, sizeof(buf1), po, &po->operand[0]));
+            strcat(g_comment, "add64");
+          }
+          else {
+            fprintf(fout, "  cond_c = ((u32)%s + %s) >> %d;\n",
+              buf1, buf2, lmod_bytes(po, po->operand[0].lmod) * 8);
+            fprintf(fout, "  %s += %s;",
+              out_dst_opr(buf1, sizeof(buf1), po, &po->operand[0]),
+              buf2);
+          }
           pfomask &= ~(1 << PFO_C);
           output_std_flags(fout, po, &pfomask, buf1);
           last_arith_dst = &po->operand[0];
@@ -4302,11 +4420,24 @@ tailcall:
         // fallthrough
       case OP_MUL:
         assert_operand_cnt(1);
-        strcpy(buf1, po->op == OP_IMUL ? "(s64)(s32)" : "(u64)");
-        fprintf(fout, "  tmp64 = %seax * %s%s;\n", buf1, buf1,
-          out_src_opr_u32(buf2, sizeof(buf2), po, &po->operand[0]));
-        fprintf(fout, "  edx = tmp64 >> 32;\n");
-        fprintf(fout, "  eax = tmp64;");
+        switch (po->operand[0].lmod) {
+        case OPLM_DWORD:
+          strcpy(buf1, po->op == OP_IMUL ? "(s64)(s32)" : "(u64)");
+          fprintf(fout, "  tmp64 = %seax * %s%s;\n", buf1, buf1,
+            out_src_opr_u32(buf2, sizeof(buf2), po, &po->operand[0]));
+          fprintf(fout, "  edx = tmp64 >> 32;\n");
+          fprintf(fout, "  eax = tmp64;");
+          break;
+        case OPLM_BYTE:
+          strcpy(buf1, po->op == OP_IMUL ? "(s16)(s8)" : "(u16)(u8)");
+          fprintf(fout, "  LOWORD(eax) = %seax * %s;", buf1,
+            out_src_opr(buf2, sizeof(buf2), po, &po->operand[0],
+              buf1, 0));
+          break;
+        default:
+          ferr(po, "TODO: unhandled mul type\n");
+          break;
+        }
         last_arith_dst = NULL;
         delayed_flag_op = NULL;
         break;
@@ -4500,8 +4631,15 @@ tailcall:
             tmp_op = pp->arg[arg].datap;
             if (tmp_op == NULL)
               ferr(po, "parsed_op missing for arg%d\n", arg);
-            if (tmp_op->argnum != 0) {
-              fprintf(fout, "%ss_a%d", cast, tmp_op->argnum);
+
+            if (tmp_op->flags & OPF_VAPUSH) {
+              fprintf(fout, "ap");
+            }
+            else if (tmp_op->p_argpass != 0) {
+              fprintf(fout, "a%d", tmp_op->p_argpass);
+            }
+            else if (tmp_op->p_argnum != 0) {
+              fprintf(fout, "%ss_a%d", cast, tmp_op->p_argnum);
             }
             else {
               fprintf(fout, "%s",
@@ -4589,9 +4727,9 @@ tailcall:
 
       case OP_PUSH:
         out_src_opr_u32(buf1, sizeof(buf1), po, &po->operand[0]);
-        if (po->argnum != 0) {
+        if (po->p_argnum != 0) {
           // special case - saved func arg
-          fprintf(fout, "  s_a%d = %s;", po->argnum, buf1);
+          fprintf(fout, "  s_a%d = %s;", po->p_argnum, buf1);
           break;
         }
         else if (po->flags & OPF_RSAVE) {