/*
 * subst.c - Substitution rules
 *
 * Copyright 2012 by Werner Almesberger
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 */


#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <ctype.h>
#include <sys/types.h>
#include <regex.h>
#include <assert.h>

#include "util.h"
#include "vstring.h"
#include "lang.h"
#include "relop.h"
#include "subst.h"


/*
 * Name of the pseudo-variable "FN" as returned by unique(); set in
 * subst_init() and compared by pointer (e.g. in subst_assign()).
 */
const char *fn;


/* ----- Rule set construction --------------------------------------------- */


/*
 * Allocate a new rule of the given type. Only "type" and "next" are set
 * here; the calling constructor fills in the type-specific fields.
 */
static struct subst *alloc_subst(enum subst_type type)
{
	struct subst *fresh = alloc_type(struct subst);

	fresh->next = NULL;
	fresh->type = type;
	return fresh;
}


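/*
 * With M the SI multiplier prefixes and U the unit character, our regexp
 * is
 *
 * (-?[0-9]+\.?[0-9]*M?U?|-?[0-9]+[UM][0-9]*)
 *
 * The first part is for things like 10, 1.2k, 3.3V, -2mA, etc.
 * The second part is for things like 1k20, 1R2, etc., where [UM] is a
 * single bracket expression containing both the unit and the multipliers.
 * A unit character of '#' means "no unit": U is then left out entirely.
 */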

/*
 * Append to *res the capturing group that matches a numeric value with an
 * optional SI multiplier and unit character "unit". A unit of '#' means
 * "no unit character".
 */
static void unit_expr(char **res, int *res_len, char unit)
{
	append(res, res_len, "(-?[0-9]+\\.?[0-9]*[" MULT_CHARS "]?");
	/*
	 * Emit the optional unit only when there is one. Without a unit, an
	 * unconditional '?' would produce "[...]??": two adjacent duplication
	 * operators, which POSIX leaves undefined in EREs and which some
	 * regcomp() implementations (e.g. Spencer-derived BSD libc) reject
	 * with REG_BADRPT.
	 */
	if (unit != '#') {
		append_char(res, res_len, unit);
		append_char(res, res_len, '?');
	}
	append(res, res_len, "|-?[0-9]+[");
	if (unit != '#')
		append_char(res, res_len, unit);
	append(res, res_len, MULT_CHARS "][0-9]*)");
}


/*
 * Convert a glob-like pattern into an anchored POSIX extended regexp.
 *
 * Translation:
 *   .      literal dot ("\.")
 *   *      any sequence (".*")
 *   ?      any single character (".")
 *   \x     copied unchanged, including the backslash
 *   (#U)   replaced by a numeric-value group (see unit_expr()); U is a
 *          letter, '%', or '#' (no unit)
 *   other  copied verbatim
 *
 * @re:    source pattern
 * @units: array of UNIT_SLOTS (10) chars. It is cleared, and then
 *         units[n-1] receives U for the n-th opening parenthesis if that
 *         parenthesis starts a (#U) group.
 * Returns a newly allocated regexp string; the caller owns it.
 * Errors are reported via yyerrorf().
 */
static char *prepare_re(const char *re, char *units)
{
	/* number of entries in units[] */
	enum { UNIT_SLOTS = 10 };
	char *res = NULL;
	int res_len = 0;
	int parens = 0;

	memset(units, 0, UNIT_SLOTS);
	append_char(&res, &res_len, '^');
	while (*re) {
		switch (*re) {
		case '.':
			append_n(&res, &res_len, "\\.", 2);
			break;
		case '*':
			append_n(&res, &res_len, ".*", 2);
			break;
		case '?':
			append_char(&res, &res_len, '.');
			break;
		case '\\':
			if (!re[1])
				yyerrorf("regexp ends with backslash");
			append_n(&res, &res_len, re, 2);
			re++;
			break;
		case '(':
			parens++;
			if (re[1] == '#' && re[2]) {
				/* ctype functions require a non-negative value */
				unsigned char unit = re[2];

				if ((!isalpha(unit) &&
				    unit != '%' && unit != '#') ||
				    re[3] != ')')
					yyerrorf("invalid (#unit) syntax");
				/* don't write past the end of units[] */
				if (parens > UNIT_SLOTS)
					yyerrorf("too many parentheses before (#unit)");
				units[parens-1] = re[2];
				unit_expr(&res, &res_len, re[2]);
				re += 3;
				break;
			}
			/* fall through */
		default:
			append_char(&res, &res_len, *re);
		}
		re++;
	}
	append_char(&res, &res_len, '$');
	return res;
}


/*
 * Create a match rule that tests variable "src" against the glob-like
 * pattern "re" (translated by prepare_re() and compiled as an ERE).
 * If "res" is non-NULL, the generated regexp string is stored there and
 * the caller takes ownership; otherwise it is freed here.
 * Compilation errors are reported via yyerrorf().
 */
struct subst *subst_match(const char *src, const char *re, char **res)
{
	struct subst *sub = alloc_subst(st_match);
	char *expanded;
	int rc;

	sub->u.match.src = src;
	expanded = prepare_re(re, sub->u.match.units);
	rc = regcomp(&sub->u.match.re, expanded, REG_EXTENDED);

	if (!res)
		free(expanded);
	else
		*res = expanded;

	if (rc) {
		char msg[1000];

		regerror(rc, &sub->u.match.re, msg, sizeof(msg));
		yyerrorf("%s", msg);
	}
	return sub;
}


/*
 * If the text span [start, s) is non-empty, append a copy of it as a
 * ct_string chunk at **last and advance *last to the new chunk's "next".
 */
static void end_chunk(struct chunk ***last, const char *start, const char *s)
{
	struct chunk *chunk;

	/* nothing accumulated since the previous boundary */
	if (start == s)
		return;

	chunk = alloc_type(struct chunk);
	chunk->type = ct_string;
	chunk->u.s = stralloc_n(start, s - start);
	chunk->next = NULL;

	/* link in and move the tail pointer */
	**last = chunk;
	*last = &chunk->next;
}


/*
 * Parse a reference that follows a '$' in an assignment pattern.
 *
 * Accepted forms (each may also be written in braces, e.g. ${NAME}):
 *   $1 ... $9   ct_sub with index 1-9 (a single non-zero digit)
 *   $$          ct_sub with index 0 (regex group 0, presumably the
 *               whole match)
 *   $NAME       ct_var; NAME is alphanumeric and stored via unique_n()
 *
 * @c: chunk to fill in (type and value)
 * @s: text immediately after the '$'
 * Returns a pointer to the first character after the reference.
 * Malformed references are reported via yyerror().
 */
static const char *parse_var(struct chunk *c, const char *s)
{
	const char *t;
	int braced;

	if (!*s)
		yyerror("trailing dollar sign");

	braced = *s == '{';
	if (braced)
		s++;

	t = s;
	while (*t) {
		if (braced && *t == '}')
			break;
		/* "$$" (or "${$}"): '$' is a one-character name */
		if (s == t && *t == '$') {
			t++;
			break;
		}
		/* cast: ctype functions are undefined for negative chars */
		if (!isalnum((unsigned char) *t))
			break;
		t++;
	}
	if (s == t)
		yyerror("invalid variable name");
	if (braced && !*t)
		yyerror("unterminated variable name");
	if (isdigit((unsigned char) *s)) {
		/* only $1 ... $9 are valid */
		if (t != s+1 || *s == '0')
			yyerror("invalid variable name");
		c->type = ct_sub;
		c->u.sub = *s-'0';
	} else if (isalnum((unsigned char) *s)) {
		c->type = ct_var;
		c->u.var = unique_n(s, t-s);
	} else {
		/* the only remaining non-empty case is "$$" */
		c->type = ct_sub;
		c->u.sub = 0;
	}

	if (braced) {
		if (*t != '}')
			yyerror("invalid variable name");
		t++;
	}
	return t;
}


/*
 * Split an assignment pattern into a linked list of chunks: literal text
 * (ct_string) and '$' references (parsed by parse_var()). A backslash
 * makes the following character literal; the backslash itself is dropped.
 * Returns the list head, or NULL for an empty pattern.
 */
static struct chunk *parse_pattern(const char *s)
{
	struct chunk *head = NULL;
	struct chunk **tail = &head;
	const char *start = s;

	while (*s) {
		if (*s == '$') {
			struct chunk *ref;

			end_chunk(&tail, start, s);
			ref = alloc_type(struct chunk);
			ref->next = NULL;
			*tail = ref;
			tail = &ref->next;
			s = parse_var(ref, s + 1);
			start = s;
		} else if (*s == '\\') {
			if (!s[1])
				yyerror("trailing backslash");
			end_chunk(&tail, start, s);
			/* next literal run begins with the escaped character */
			start = s + 1;
			s += 2;
		} else {
			s++;
		}
	}
	end_chunk(&tail, start, s);
	return head;
}


/*
 * Create an assignment rule "dst op pat". The pattern is parsed into
 * chunks here; references are checked later by subst_finalize().
 * Assigning to the FN pseudo-variable is an error.
 */
struct subst *subst_assign(const char *dst, enum relop op, const char *pat)
{
	struct subst *rule;

	if (dst == fn)
		yyerror("can't assign to pseudo-variable FN");

	rule = alloc_subst(st_assign);
	rule->u.assign.dst = dst;
	rule->u.assign.op = op;
	rule->u.assign.pat = parse_pattern(pat);
	return rule;
}


/* Create a "print" rule for variable "var". */
struct subst *subst_print(const char *var)
{
	struct subst *rule = alloc_subst(st_print);

	rule->u.print = var;
	return rule;
}


/* Create an "end" rule; it carries no parameters. */
struct subst *subst_end(void)
{
	struct subst *rule = alloc_subst(st_end);

	return rule;
}


/* Create an "ignore" rule; it carries no parameters. */
struct subst *subst_ignore(void)
{
	struct subst *rule = alloc_subst(st_ignore);

	return rule;
}


/*
 * Create a "break" rule. "block" names the target match block, or is NULL
 * for the innermost enclosing block; it is resolved by subst_finalize().
 */
struct subst *subst_break(const char *block)
{
	struct subst *jump = alloc_subst(st_break);

	jump->u.tmp = block;	/* replaced by u.jump during finalization */
	return jump;
}


/*
 * Create a "continue" rule. "block" names the target match block, or is
 * NULL for the innermost enclosing block; it is resolved by
 * subst_finalize().
 */
struct subst *subst_continue(const char *block)
{
	struct subst *jump = alloc_subst(st_continue);

	jump->u.tmp = block;	/* replaced by u.jump during finalization */
	return jump;
}


/* ----- Jump resolution --------------------------------------------------- */


/*
 * One link in the chain of enclosing match rules, built from locals during
 * recurse_fin(); the chain starts at the innermost block.
 */
struct parent {
	const struct subst *sub;	/* enclosing st_match rule */
	const struct parent *parent;	/* next outer block; NULL at top level */
};


/*
 * Find the match rule that a break/continue refers to.
 *
 * @name:   source variable of the target block (compared by pointer with
 *          u.match.src), or NULL for the innermost enclosing block
 * @parent: chain of enclosing match blocks; must not be NULL
 * Returns the target match rule. An unknown name is reported via
 * yyerrorf().
 */
static const struct subst *resolve_jump(const char *name,
    const struct parent *parent)
{
	const struct parent *p;

	if (!name)
		return parent->sub;
	for (p = parent; p; p = p->parent) {
		assert(p->sub->type == st_match);
		if (name == p->sub->u.match.src)
			return p->sub;
	}
	yyerrorf("cannot find \"%s\"", name);
	/*
	 * Not reached if yyerrorf() terminates. The return keeps this non-void
	 * function from falling off its end if yyerrorf() is not declared
	 * noreturn.
	 */
	return NULL;
}


/*
 * Return 1 if "var" is assigned by "sub" or any rule reachable backwards
 * through the "prev" chain, or is the source of such a match rule (FN
 * excepted); return 0 otherwise.
 */
static int find_var_use(const char *var, const struct subst *sub)
{
	const struct subst *cur;

	for (cur = sub; cur; cur = cur->prev) {
		if (cur->type == st_assign && cur->u.assign.dst == var)
			return 1;
		if (cur->type == st_match && var != fn &&
		    cur->u.match.src == var)
			return 1;
	}
	return 0;
}


/*
 * Validate the references in an assignment pattern.
 *
 * - $1 ... $9 and $$ need an enclosing match ("parent"), and the index
 *   must not exceed that regexp's number of parenthesized groups.
 * - $NAME must be set earlier, according to find_var_use() walking back
 *   from "prev".
 * Problems are reported via yyerrorf().
 */
static void check_chunks(const struct chunk *c, const struct parent *parent,
    const struct subst *prev)
{
	for (; c; c = c->next) {
		if (c->type == ct_sub) {
			int groups;

			if (!parent)
				yyerrorf("$%c without match",
				    c->u.sub ? c->u.sub+'0' : '$');
			groups = parent->sub->u.match.re.re_nsub;
			if (c->u.sub > groups)
				yyerrorf("$%d but only %d parenthes%s",
				    c->u.sub, groups,
				    groups == 1 ? "is" : "es");
		} else if (c->type == ct_var) {
			if (!find_var_use(c->u.var, prev))
				yyerrorf("$%s may be undefined", c->u.var);
		}
	}
}


static void recurse_fin(struct subst *sub, const struct parent *parent)
{
	struct parent next = {
		.parent = parent,
	};
	const struct subst *prev;

	prev = parent ? parent->sub : NULL;
	while (sub) {
		sub->prev = prev;
		switch (sub->type) {
		case st_match:
			if (!parent && sub->u.match.src == dollar)
				yyerror("$ without match");
			next.sub = sub;
			recurse_fin(sub->u.match.block, &next);
			break;
		case st_assign:
			if (!parent && sub->u.assign.dst == dollar)
				yyerror("$ without match");
			check_chunks(sub->u.assign.pat, parent, prev);
			break;
		case st_print:
			break;
		case st_end:
			break;
		case st_ignore:
			break;
		case st_break:
			/* fall through */
		case st_continue:
			if (!parent)
				yyerror("jump without block");
			sub->u.jump = resolve_jump(sub->u.tmp, parent);
			break;
		default:
			abort();
		}
		prev = sub;
		sub = sub->next;
	}
}


/*
 * Validate a complete rule set and resolve break/continue targets into
 * u.jump. subst_dump() reads u.jump, so call this before dumping.
 */
void subst_finalize(struct subst *sub)
{
	const struct parent *outermost = NULL;	/* top level: no enclosing block */

	recurse_fin(sub, outermost);
}


/* ----- Dumping ----------------------------------------------------------- */


/* Spaces per nesting level in subst_dump() output. */
#define	INDENT	4


/*
 * Print a parsed assignment pattern. Variables are printed as ${NAME},
 * substitutions as $N or $$, and literal text as-is.
 */
static void dump_chunks(FILE *file, const struct chunk *c)
{
	for (; c; c = c->next) {
		switch (c->type) {
		case ct_string:
			fputs(c->u.s, file);
			break;
		case ct_var:
			fprintf(file, "${%s}", c->u.var);
			break;
		case ct_sub:
			if (!c->u.sub)
				fputs("$$", file);
			else
				fprintf(file, "$%d", c->u.sub);
			break;
		default:
			abort();
		}
	}
}


/*
 * Print a list of sibling rules, indented by INDENT spaces per "level".
 * Match blocks are printed recursively, one level deeper.
 */
static void recurse_dump(FILE *file, const struct subst *sub, int level)
{
	const int pad = INDENT * level;

	for (; sub; sub = sub->next) {
		fprintf(file, "%*s", pad, "");
		switch (sub->type) {
		case st_match:
			fprintf(file, "%s=RE {\n", sub->u.match.src);
			recurse_dump(file, sub->u.match.block, level + 1);
			fprintf(file, "%*s}\n", pad, "");
			break;
		case st_assign:
			fputs(sub->u.assign.dst, file);
			dump_relop(file, sub->u.assign.op);
			dump_chunks(file, sub->u.assign.pat);
			fputs("\n", file);
			break;
		case st_print:
			fprintf(file, "print %s\n", sub->u.print);
			break;
		case st_end:
			fputs("end\n", file);
			break;
		case st_ignore:
			fputs("ignore\n", file);
			break;
		case st_break:
			/* u.jump is only valid after subst_finalize() */
			fprintf(file, "break %s\n", sub->u.jump->u.match.src);
			break;
		case st_continue:
			fprintf(file, "continue %s\n",
			    sub->u.jump->u.match.src);
			break;
		default:
			abort();
		}
	}
}


/* Print the whole rule set to "file", starting at indentation level 0. */
void subst_dump(FILE *file, const struct subst *sub)
{
	const int top_level = 0;

	recurse_dump(file, sub, top_level);
}


/* Module initialization: obtain the unique() name of the FN pseudo-variable. */
void subst_init(void)
{
	const char *name = unique("FN");

	fn = name;
}