/*
 * Platform detection for I/O multiplexing.
 *
 * To manually override, add one of these at the very top of this file:
 *   #define USE_KQUEUE    (requires kqueue support: macOS, FreeBSD, etc.)
 *   #define USE_PPOLL     (requires ppoll support: Linux, etc.)
 *
 * Note: Ensure the chosen method is available on your platform, or you'll
 * get compilation errors.
 */
#if !defined(USE_KQUEUE) && !defined(USE_PPOLL)
# if defined(__APPLE__) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__)
#  define USE_KQUEUE
# else
#  define USE_PPOLL
# endif
#endif

/*
 * ppoll(2) is a GNU extension in glibc: its prototype is only visible when
 * _GNU_SOURCE is defined before the first system header is included.
 */
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <limits.h>
#ifdef USE_KQUEUE
#include <sys/event.h>
#endif
#ifdef USE_PPOLL
#include <poll.h>
#endif
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>
#include <stdint.h>
#include <getopt.h>
#include <pthread.h>
#include <math.h>

#include "libpq-fe.h"

/* Convert a microsecond count into seconds, as a double. */
#define PG_TIME_GET_DOUBLE(t) (0.000001 * (t))
/* libpq connection URI used by every notifier and listener thread. */
#define CONNECTION_STRING "postgresql:///postgres"
#define MAX_NOTIFIERS 256  /* Maximum notifiers per channel for sequence tracking */

/*
 * Latency histogram buckets.
 *
 * Both arrays are indexed by the value returned from get_latency_bucket()
 * and are only read or written while histogram_mutex is held: listener
 * threads update them in update_histogram(), main() reads them to print.
 */
#define NUM_BUCKETS 6
static uint64_t bucket_counts[NUM_BUCKETS];  /* Number of samples per bucket */
static uint64_t bucket_totals[NUM_BUCKETS];  /* Total latency in microseconds */
static pthread_mutex_t histogram_mutex = PTHREAD_MUTEX_INITIALIZER;

/*
 * Global throughput counters. Worker threads bump them with
 * __sync_fetch_and_add(); main() samples them once per second.
 */
static uint32_t num_notifies_sent;  /* Successful NOTIFY commands, all notifiers */
static uint32_t num_notifies_received;  /* Notifications dequeued, all listeners */
/* Set to 1 by main() once every listener is ready; polled by notifier threads */
static volatile int start_notifying = 0;  /* Signal for notifiers to start */

/*
 * Synchronization for listener LISTEN setup: each listener increments
 * listeners_ready under startup_mutex after its LISTEN succeeded, and the
 * last one signals startup_cond, which main() waits on.
 */
static int listeners_ready = 0;
static pthread_mutex_t startup_mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t startup_cond = PTHREAD_COND_INITIALIZER;

/*
 * --extra-channels argument: how many never-notified "extraN" channels each
 * listener additionally LISTENs on. Written by main() before any thread is
 * created, read by the listener threads.
 */
static int			num_extra_channels = 0;

/*
 * Thread arguments structure.
 *
 * main() fills in one instance per spawned thread and passes its address as
 * the pthread start argument. Notifier and listener threads share the same
 * layout; fields a thread kind does not use are set to -1 / NULL.
 */
struct thread_args {
	int channel_id;  /* Channel number; its decimal text is the channel name */
	int notifier_id;  /* ID of notifier within channel (0 to num_notify_threads-1) */
	int num_notifiers;  /* Total number of notifiers per channel */
	int total_listeners;  /* Total number of listener threads (all channels) */
	uint64_t *seq_counter;  /* Pointer to this notifier's sequence counter */
};

/* Wall-clock timestamp or duration, in microseconds. */
typedef int64_t pg_time_usec_t;

/*
 * Return the current wall-clock time as microseconds since the Unix epoch.
 *
 * The value comes from gettimeofday(), so it follows system clock
 * adjustments and is not guaranteed to be monotonic.
 */
static inline pg_time_usec_t
pg_time_now(void)
{
	struct timeval tv;

	gettimeofday(&tv, NULL);

	/*
	 * Widen the seconds to 64 bits *before* scaling: multiplying in time_t
	 * width overflows on platforms where time_t is only 32 bits wide.
	 */
	return (pg_time_usec_t) tv.tv_sec * 1000000 + (pg_time_usec_t) tv.tv_usec;
}

/*
 * Close the given libpq connection and terminate the whole process with
 * exit status 1.
 *
 * This is called from worker threads on fatal errors; because it calls
 * exit(), it ends every thread of the program, not just the caller.
 */
static void
exit_nicely(PGconn *conn)
{
	PQfinish(conn);
	exit(1);
}

/*
 * Map a latency given in milliseconds to its histogram bucket index.
 *
 * Buckets: 0-0.01ms, 0.01-0.1ms, 0.1-1ms, 1-10ms, 10-100ms, >100ms.
 * Every upper bound is exclusive, so a value exactly on a boundary belongs
 * to the next bucket. Returns an index in the range 0..5.
 */
static int
get_latency_bucket(double latency_ms)
{
	/* Exclusive upper bounds of buckets 0..4; bucket 5 has no upper bound */
	static const double upper_bounds_ms[] = {0.01, 0.1, 1.0, 10.0, 100.0};
	const int	num_bounded = (int) (sizeof(upper_bounds_ms) / sizeof(upper_bounds_ms[0]));
	int			idx;

	/* Stop at the first bucket whose upper bound lies above the sample */
	for (idx = 0; idx < num_bounded; idx++)
	{
		if (latency_ms < upper_bounds_ms[idx])
			break;
	}

	/* Falls out of the loop with idx == 5 when no bound matched */
	return idx;
}

/*
 * Record one latency sample, given in microseconds, in the global histogram.
 *
 * The bucket is chosen from the latency converted to milliseconds; the
 * bucket's sample count and its running microsecond total are then updated
 * while holding histogram_mutex.
 */
static void
update_histogram(uint64_t latency_usec)
{
	/* Bucket selection touches no shared state, so do it before locking */
	const int	slot = get_latency_bucket(latency_usec / 1000.0);

	pthread_mutex_lock(&histogram_mutex);
	bucket_totals[slot] += latency_usec;
	bucket_counts[slot] += 1;
	pthread_mutex_unlock(&histogram_mutex);
}

/*
 * Thread entry point for one notifier.
 *
 * Opens its own database connection, waits until main() raises
 * start_notifying, then issues NOTIFY commands back to back, forever, on the
 * channel whose name is the decimal text of args->channel_id.
 *
 * Each payload has the form "<notifier_id> <seq> <send_time_usec>", where
 * seq is taken from (and then increments) *args->seq_counter and
 * send_time_usec is pg_time_now() sampled just before the command is built.
 *
 * arg: pointer to a struct thread_args; channel_id, notifier_id and
 *      seq_counter are read.
 *
 * Never returns. A failed connection or a failed NOTIFY terminates the whole
 * process through exit_nicely().
 */
static void *
notify_thread_main(void *arg)
{
	struct thread_args *args = (struct thread_args *)arg;
	int channel_id = args->channel_id;
	int notifier_id = args->notifier_id;
	uint64_t *seq_counter = args->seq_counter;
	PGconn	   *conn;
	PGresult   *res;
	char		channel_name[32];

	/* Generate channel name from channel_id */
	snprintf(channel_name, sizeof(channel_name), "%d", channel_id);

	/* Make a connection to the database */
	conn = PQconnectdb(CONNECTION_STRING);

	/* Check to see that the backend connection was successfully made */
	if (PQstatus(conn) != CONNECTION_OK)
	{
		fprintf(stderr, "%s", PQerrorMessage(conn));
		exit_nicely(conn);
	}

	/*
	 * Wait for signal to start notifying.  The flag is written by another
	 * thread, so read it with an atomic acquire load instead of relying on
	 * "volatile", which does not provide inter-thread synchronization.
	 */
	while (!__atomic_load_n(&start_notifying, __ATOMIC_ACQUIRE))
		usleep(10000);  /* Sleep 10ms */

	for(;;)
	{
		char buf[128];
		pg_time_usec_t send_time;
		uint64_t seq;

		/* Get timestamp before sending */
		send_time = pg_time_now();

		/* Atomically get and increment this notifier's sequence counter */
		seq = __sync_fetch_and_add(seq_counter, 1);

		/* Send notification with notifier_id, sequence number, and timestamp */
		snprintf(buf, sizeof(buf), "NOTIFY \"%s\", '%d %lld %lld'",
				 channel_name, notifier_id, (long long)seq, (long long)send_time);
		res = PQexec(conn, buf);
		if (PQresultStatus(res) != PGRES_COMMAND_OK)
		{
			fprintf(stderr, "NOTIFY command failed: %s", PQerrorMessage(conn));
			PQclear(res);
			exit_nicely(conn);
		}
		PQclear(res);

		/* Only successfully executed NOTIFYs are counted */
		__sync_fetch_and_add(&num_notifies_sent, 1);
	}
}

static void *
listen_thread_main(void *arg)
{
	struct thread_args *args = (struct thread_args *)arg;
	int channel_id = args->channel_id;
	int num_notifiers = args->num_notifiers;
	PGconn	   *conn;
	PGresult   *res;
	PGnotify   *notify;
	uint64_t	expected_seq[MAX_NOTIFIERS];
	char		channel_name[32];
	char		listen_cmd[64];

	/* Initialize expected sequence for each notifier */
	for (int i = 0; i < MAX_NOTIFIERS; i++)
		expected_seq[i] = 0;

	/* Generate channel name from channel_id */
	snprintf(channel_name, sizeof(channel_name), "%d", channel_id);

	/* Make a connection to the database */
	conn = PQconnectdb(CONNECTION_STRING);

	/* Check to see that the backend connection was successfully made */
	if (PQstatus(conn) != CONNECTION_OK)
	{
		fprintf(stderr, "%s", PQerrorMessage(conn));
		exit_nicely(conn);
	}


	/*
	 * Issue LISTEN command for "extra" channels. The extra channels are never
	 * notified, they're used just to bloat the list of channels that notify
	 * processing needs to traverse.
	 */
	for (int i = 0; i < num_extra_channels; i++)
	{
		snprintf(listen_cmd, sizeof(listen_cmd), "LISTEN \"extra%d\"", i);
		res = PQexec(conn, listen_cmd);
		if (PQresultStatus(res) != PGRES_COMMAND_OK)
		{
			fprintf(stderr, "LISTEN command failed: %s", PQerrorMessage(conn));
			PQclear(res);
			exit_nicely(conn);
		}
		PQclear(res);
	}

	/*
	 * Issue LISTEN command to enable notifications from the rule's NOTIFY.
	 */
	snprintf(listen_cmd, sizeof(listen_cmd), "LISTEN \"%s\"", channel_name);
	res = PQexec(conn, listen_cmd);
	if (PQresultStatus(res) != PGRES_COMMAND_OK)
	{
		fprintf(stderr, "LISTEN command failed: %s", PQerrorMessage(conn));
		PQclear(res);
		exit_nicely(conn);
	}
	PQclear(res);

	/* Signal that this listener is ready */
	pthread_mutex_lock(&startup_mutex);
	listeners_ready++;
	if (listeners_ready == args->total_listeners)
		pthread_cond_signal(&startup_cond);  /* Wake main thread if we're the last */
	pthread_mutex_unlock(&startup_mutex);

	for (;;)
	{
		/*
		 * Sleep until something happens on the connection.  We use kqueue(2)
		 * on macOS/BSD for better performance and scalability, or ppoll(2) on
		 * other platforms. Both avoid the FD_SETSIZE limitation of select().
		 */
		int			sock;

		sock = PQsocket(conn);

		if (sock < 0)
			break;				/* shouldn't happen */

#ifdef USE_KQUEUE
		/* Use kqueue for better performance and scalability */
		int kq;
		struct kevent kev;
		struct kevent event;

		kq = kqueue();
		if (kq < 0)
		{
			fprintf(stderr, "kqueue() failed: %s\n", strerror(errno));
			exit_nicely(conn);
		}

		/* Monitor the socket for read events */
		EV_SET(&kev, sock, EVFILT_READ, EV_ADD | EV_ONESHOT, 0, 0, NULL);

		/* Wait indefinitely for an event */
		if (kevent(kq, &kev, 1, &event, 1, NULL) < 0)
		{
			fprintf(stderr, "kevent() failed: %s\n", strerror(errno));
			close(kq);
			exit_nicely(conn);
		}

		close(kq);
#elif defined(USE_PPOLL)
		/* Use ppoll (nanosecond resolution) */
		struct pollfd pfd;

		pfd.fd = sock;
		pfd.events = POLLIN;
		pfd.revents = 0;

		if (ppoll(&pfd, 1, NULL, NULL) < 0)
		{
			fprintf(stderr, "ppoll() failed: %s\n", strerror(errno));
			exit_nicely(conn);
		}
#else
# error "No I/O multiplexing method defined. Define USE_KQUEUE or USE_PPOLL."
#endif

		/* Now check for input */
		PQconsumeInput(conn);
		while ((notify = PQnotifies(conn)) != NULL)
		{
			pg_time_usec_t recv_time;
			pg_time_usec_t send_time;
			int notifier_id;
			uint64_t seq;
			uint64_t latency_usec;

			/* Get receive timestamp */
			recv_time = pg_time_now();

			/* Parse notifier_id, sequence number, and send timestamp from payload */
			if (notify->extra && notify->extra[0])
			{
				if (sscanf(notify->extra, "%d %lld %lld", &notifier_id, (long long *)&seq, (long long *)&send_time) == 3)
				{
					/* Validate notifier_id */
					if (notifier_id < 0 || notifier_id >= num_notifiers)
					{
						fprintf(stderr, "\nERROR: Channel %d received invalid notifier_id %d (expected 0-%d)\n",
								channel_id, notifier_id, num_notifiers - 1);
						abort();
					}

					/* Verify sequence number for this notifier */
					if (seq != expected_seq[notifier_id])
					{
						fprintf(stderr, "\nERROR: Channel %d notifier %d sequence gap! Expected %lld, received %lld\n",
								channel_id, notifier_id, (long long)expected_seq[notifier_id], (long long)seq);
						abort();
					}
					expected_seq[notifier_id]++;

					latency_usec = recv_time - send_time;

					/* Update histogram */
					update_histogram(latency_usec);
				}
			}

			PQfreemem(notify);
			PQconsumeInput(conn);

			__sync_fetch_and_add(&num_notifies_received, 1);
		}
	}

	return NULL;
}

int
main(int argc, char **argv)
{
	int			num_threads = 0;
	pthread_t  *threads;
	pg_time_usec_t start;
	static struct option long_options[] = {
		/* systematic long/short named options */
		{"listeners", required_argument, NULL, 1},
		{"notifiers", required_argument, NULL, 2},
		{"channels", required_argument, NULL, 3},
		{"extra-channels", required_argument, NULL, 4},
		{NULL, 0, NULL, 0}
	};
	int			num_listen_threads = 1;
	int			num_notify_threads = 1;
	int			num_channels = 1;
	int			optindex;
	int			c;

	while ((c = getopt_long(argc, argv, "", long_options, &optindex)) != -1)
	{
		switch (c)
		{
			case 1:				/* listeners */
				num_listen_threads = atoi(optarg);
				if (num_listen_threads < 1)
				{
					fprintf(stderr, "invalid --listeners argument\n");
					exit(1);
				}
				break;

			case 2:				/* notifiers */
				num_notify_threads = atoi(optarg);
				if (num_notify_threads < 1)
				{
					fprintf(stderr, "invalid --notifiers argument\n");
					exit(1);
				}
				break;

			case 3:				/* channels */
				num_channels = atoi(optarg);
				if (num_channels < 1)
				{
					fprintf(stderr, "invalid --channels argument\n");
					exit(1);
				}
				break;

			case 4:				/* extra-channels */
				num_extra_channels = atoi(optarg);
				if (num_extra_channels < 0)
				{
					fprintf(stderr, "invalid --extra-channels argument\n");
					exit(1);
				}
				break;
		}
	}

	int total_threads = num_channels * (num_notify_threads + num_listen_threads);
	threads = malloc(total_threads * sizeof(pthread_t));
	struct thread_args *thread_args_array = malloc(total_threads * sizeof(struct thread_args));

	/* Allocate sequence counters for each notifier thread (initialized to 0) */
	uint64_t *notifier_seqs = calloc(num_channels * num_notify_threads, sizeof(uint64_t));

	/* Spawn threads for each channel */
	for (int channel_id = 0; channel_id < num_channels; channel_id++)
	{
		/* Spawn notifier threads for this channel */
		for (int i = 0; i < num_notify_threads; i++)
		{
			int			s;
			int			notifier_index = channel_id * num_notify_threads + i;

			thread_args_array[num_threads].channel_id = channel_id;
			thread_args_array[num_threads].notifier_id = i;
			thread_args_array[num_threads].num_notifiers = num_notify_threads;
			thread_args_array[num_threads].total_listeners = num_channels * num_listen_threads;
			thread_args_array[num_threads].seq_counter = &notifier_seqs[notifier_index];
			s = pthread_create(&threads[num_threads], NULL,
							   &notify_thread_main, &thread_args_array[num_threads]);
			if (s != 0)
			{
				fprintf(stderr, "pthread_create failed\n");
				exit(1);
			}
			num_threads++;
		}

		/* Spawn listener threads for this channel */
		for (int i = 0; i < num_listen_threads; i++)
		{
			int			s;

			thread_args_array[num_threads].channel_id = channel_id;
			thread_args_array[num_threads].notifier_id = -1;  /* Not used for listeners */
			thread_args_array[num_threads].num_notifiers = num_notify_threads;
			thread_args_array[num_threads].total_listeners = num_channels * num_listen_threads;
			thread_args_array[num_threads].seq_counter = NULL;  /* Not used for listeners */
			s = pthread_create(&threads[num_threads], NULL,
							   &listen_thread_main, &thread_args_array[num_threads]);
			if (s != 0)
			{
				fprintf(stderr, "pthread_create failed\n");
				exit(1);
			}
			num_threads++;
		}
	}

	/* Wait for all listeners to establish LISTEN before notifiers start sending */
	pthread_mutex_lock(&startup_mutex);
	while (listeners_ready < num_channels * num_listen_threads)
		pthread_cond_wait(&startup_cond, &startup_mutex);
	pthread_mutex_unlock(&startup_mutex);

	/* Signal notifiers to start */
	start_notifying = 1;

	start = pg_time_now();

	uint32_t prev_sent = 0;
	uint32_t prev_received = 0;
	int first_iteration = 1;

	for (;;)
	{
		double		elapsed_sec;
		uint32_t	curr_sent;
		uint32_t	curr_received;
		uint32_t	sent_per_sec;
		uint32_t	received_per_sec;

		sleep(1);

		/* Move cursor back up before printing (except first time) */
		if (!first_iteration)
			fprintf(stderr, "\033[%dA\r", NUM_BUCKETS + 1);
		first_iteration = 0;

		elapsed_sec = PG_TIME_GET_DOUBLE(pg_time_now() - start);

		curr_sent = num_notifies_sent;
		curr_received = num_notifies_received;
		sent_per_sec = curr_sent - prev_sent;
		received_per_sec = curr_received - prev_received;

		/* Print stats on same line */
		fprintf(stderr, "\r%.0f s: %u sent (%u/s), %u received (%u/s)    ",
				elapsed_sec, curr_sent, sent_per_sec, curr_received, received_per_sec);

		/* Print histogram */
		pthread_mutex_lock(&histogram_mutex);

		uint64_t total_measured = 0;
		for (int i = 0; i < NUM_BUCKETS; i++)
			total_measured += bucket_counts[i];

		if (total_measured > 0)
		{
			const char *bucket_labels[] = {
				" 0.00-0.01ms   ",
				" 0.01-0.10ms   ",
				" 0.10-1.00ms   ",
				" 1.00-10.00ms  ",
				" 10.00-100.00ms",
				">100.00ms     "
			};

			fprintf(stderr, "\n");
			for (int i = 0; i < NUM_BUCKETS; i++)
			{
				uint64_t count = bucket_counts[i];
				double percentage = (count * 100.0) / total_measured;
				double avg_latency_ms = 0.0;

				if (count > 0)
					avg_latency_ms = (bucket_totals[i] / 1000.0) / count;

				/* Draw bar chart (max 10 chars) */
				int bar_length = (int)((count * 10) / total_measured);
				if (bar_length == 0 && count > 0)
					bar_length = 1;

				fprintf(stderr, "%s  ", bucket_labels[i]);
				for (int j = 0; j < bar_length; j++)
					fprintf(stderr, "#");
				for (int j = bar_length; j < 10; j++)
					fprintf(stderr, " ");

				fprintf(stderr, " %llu (%.1f%%) avg: %.3fms\n",
						(unsigned long long)count, percentage, avg_latency_ms);
			}
		}

		pthread_mutex_unlock(&histogram_mutex);
		fflush(stderr);

		prev_sent = curr_sent;
		prev_received = curr_received;
	}

	return 0;
}
