#include <sys/param.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <assert.h>
#include <err.h>
#include <fcntl.h>
#include <libutil.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/*
 * Page sizes filled in by getpagesizes() in main(): pagesize[0] is
 * reported as the normal page size and pagesize[1] as the super page size.
 */
size_t pagesize[2];

/*
 * Inspect the residency of the region [start, start + len) with mincore(2)
 * and report when it does not match the expected small/super page layout.
 *
 * The region is first widened to base-page boundaries.  "leading" is the
 * number of base pages before the first super page boundary and "trailing"
 * the number of base pages left over after the last whole super page; those
 * are the only pages expected to be resident as small pages.  Pages flagged
 * MINCORE_SUPER are counted as super, pages flagged only MINCORE_INCORE as
 * small.  A diagnostic prefixed with 'msg' is printed to stdout when the
 * small page count differs from leading + trailing.
 *
 * The count of MINCORE_SUPER pages is asserted to be a whole multiple of
 * the number of base pages per super page.
 *
 * Terminates the program via err(3) if the residency vector cannot be
 * allocated or if mincore() fails.
 */
static void
check_region(char *start, size_t len, const char *msg)
{
	size_t offset;
	char *info;
	int i, leading, trailing, pages, super, small;

	offset = (uintptr_t)start % pagesize[0];
	if (offset != 0) {
		printf("%s: start address is not page aligned!\n", msg);
		/* Extend the region down to the enclosing page boundary. */
		len += offset;
		start -= offset;
	}
	offset = (uintptr_t)start % pagesize[1];
	if (offset != 0) {
		printf("%s: start address is not super aligned\n", msg);
		leading = (pagesize[1] - offset) / pagesize[0];
	} else
		leading = 0;
	len = roundup2(len, pagesize[0]);
	pages = len / pagesize[0];

	/*
	 * A region that ends before the first super page boundary consists
	 * solely of leading pages.  Clamp so that the unsigned subtraction
	 * below cannot wrap around.
	 */
	if (leading > pages)
		leading = pages;
	trailing = (len - leading * pagesize[0]) % pagesize[1] / pagesize[0];

	/* mincore() fills in one status byte per base page. */
	info = malloc(pages);
	if (info == NULL)
		err(1, "%s: malloc", msg);
	if (mincore(start, len, info) < 0)
		err(1, "%s: mincore", msg);

	small = super = 0;
	for (i = 0; i < pages; i++) {
		if (info[i] & MINCORE_SUPER)
			super++;
		else if (info[i] & MINCORE_INCORE)
			small++;
	}
	free(info);

	assert(super % (pagesize[1] / pagesize[0]) == 0);
	if (small != (leading + trailing))
		printf("%s: expected %d super / %d small, found %d / %d\n",
		    msg, pages - (leading + trailing), leading + trailing,
		    super, small);
}

/*
 * Open an anonymous shared memory object (SHM_ANON) for reading and
 * writing, resize it to 'size' bytes and hand back its descriptor.
 * Either step failing terminates the program via err(3).
 */
static int
create_shm(size_t size)
{
	const int shmfd = shm_open(SHM_ANON, O_RDWR, 0644);

	if (shmfd < 0) {
		err(1, "shm_open(SHM_ANON)");
	}
	if (ftruncate(shmfd, size) < 0) {
		err(1, "ftruncate");
	}
	return (shmfd);
}

/*
 * Open 'path' read/write, creating it with mode 0644 if it does not exist
 * and truncating it if it does.  Returns the open descriptor; a failed
 * open() terminates the program via err(3).
 */
static int
create_file(const char *path)
{
	const int filefd = open(path, O_RDWR | O_CREAT | O_TRUNC, 0644);

	if (filefd < 0) {
		err(1, "open(%s)", path);
	}
	return (filefd);
}

/*
 * Truncate the file behind 'fd' to zero length and then write 'size' bytes
 * of zeros to it in chunks of at most 64KB.
 *
 * No seek is performed, so the zeros are written starting at the
 * descriptor's current file offset.  A failed ftruncate() or write()
 * terminates the program via err(3); a short write terminates it via
 * errx(3).
 */
static void
zero_file(int fd, size_t size)
{
	char zeros[64 * 1024] = { 0 };
	size_t chunk, remaining;
	ssize_t rv;

	if (ftruncate(fd, 0) < 0)
		err(1, "ftruncate");

	for (remaining = size; remaining > 0; remaining -= chunk) {
		/* Never write more than one buffer's worth at a time. */
		chunk = MIN(remaining, sizeof(zeros));
		rv = write(fd, zeros, chunk);
		if (rv < 0)
			err(1, "write");
		if ((size_t)rv != chunk)
			errx(1, "short write");
	}
}

/*
 * Map the whole object behind 'fd' read-only and shared, optionally touch
 * every base page, and then run check_region() on the mapping using 'msg'
 * as the diagnostic prefix.
 *
 * The mapping length is taken from fstat().  The mapping is not unmapped
 * and 'fd' is not closed here.  A failed fstat() or mmap() terminates the
 * program via err(3).
 *
 * NOTE(review): MAP_PREFAULT_READ is passed to mmap() regardless of
 * 'prefault'; the flag only controls the explicit page-touching loop.
 * Confirm that this is the intended meaning of the parameter.
 */
static void
map_file(int fd, bool prefault, const char *msg)
{
	struct stat st;
	char *base, *cur, *end;
	int sum;

	if (fstat(fd, &st) < 0)
		err(1, "fstat");

	base = mmap(NULL, st.st_size, PROT_READ,
	    MAP_SHARED | MAP_PREFAULT_READ, fd, 0);
	if (base == MAP_FAILED)
		err(1, "mmap");

	if (prefault) {
		/* Read the first byte of each base page in the mapping. */
		end = base + st.st_size;
		sum = 0;
		cur = base;
		while (cur < end) {
			sum += *cur;
			cur += pagesize[0];
		}
		/*
		 * Consume the sum, presumably so that the reads above
		 * cannot be optimized away.
		 */
		usleep(sum);
	}

	check_region(base, st.st_size, msg);
}

/*
 * Query the normal and super page sizes, print them, and then run
 * map_file() against a series of shared memory objects and an on-disk
 * file of various sizes.  Exits via err(3)/errx(3) unless getpagesizes()
 * reports exactly two page sizes.
 */
int
main(int ac, char **av)
{
	char label[5];
	size_t superpage;
	int filefd, nsizes, shmfd;

	nsizes = getpagesizes(pagesize, 2);
	if (nsizes < 0)
		err(1, "getpagesizes");
	if (nsizes != 2)
		errx(1, "getpagesizes returned %d sizes", nsizes);

	(void)humanize_number(label, sizeof(label), pagesize[0], "",
	    HN_AUTOSCALE, HN_NOSPACE);
	printf("Normal page size: %s\n", label);
	(void)humanize_number(label, sizeof(label), pagesize[1], "",
	    HN_AUTOSCALE, HN_NOSPACE);
	printf("Super page size: %s\n", label);

	superpage = pagesize[1];

	/* Start with a shm that is a whole number of super pages. */
	shmfd = create_shm(superpage * 4);
	map_file(shmfd, true, "super page sized shm");

	/*
	 * Next, a shm whose size is not a multiple of the super page
	 * size, followed by one whose size is.
	 */
	shmfd = create_shm(superpage * 2 + superpage / 2);
	map_file(shmfd, true, "odd sized shm");
	shmfd = create_shm(superpage * 2);
	map_file(shmfd, true, "super page sized shm (2)");

	/*
	 * Create a fresh file on disk, extend it with ftruncate() only
	 * (no data is written), and map it.
	 */
	filefd = create_file("/usr/scratch/spagefile");
	if (ftruncate(filefd, superpage * 32) < 0)
		err(1, "ftruncate");
	map_file(filefd, true, "truncated file");

	/* Re-truncate the same file, fill it with zeros, and map it again. */
	zero_file(filefd, superpage * 64);
	map_file(filefd, true, "zeroed file");

	return (0);
}
