package net.herit.svcplatform.pushservice.commons.dto.wrapper;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;

import org.springframework.lang.Nullable;
import org.springframework.util.StreamUtils;
import org.springframework.web.util.ContentCachingRequestWrapper;

/**
 * Request wrapper that reads the entire request body into memory up front so that it can be
 * consumed more than once.
 *
 * <p>Every call to {@link #getInputStream()} returns a new stream positioned at the first byte
 * of the buffered body. {@link #getContentAsByteArray()} is overridden to expose the same
 * buffer, because the caching input stream of {@link ContentCachingRequestWrapper} is bypassed
 * by this class and would otherwise never be filled.
 *
 * <p>NOTE(review): the constructor drains the wrapped request's input stream. A servlet
 * container may then be unable to parse form-encoded body parameters for {@code getParameter*}.
 * Confirm this wrapper is not applied to {@code application/x-www-form-urlencoded} posts.
 */
public class RequestWrapper extends ContentCachingRequestWrapper {

	/**
	 * Maximum number of bytes reported by {@link #getContentAsByteArray()}, or {@code null}
	 * for no limit.
	 */
	@Nullable
	private final Integer contentCacheLimit;

	/**
	 * Single stream over the body handed out by {@link #getByteArrayInputStream()}. It is
	 * shared, so bytes consumed by one caller are gone for the next.
	 */
	private final ByteArrayInputStream bis;

	/** The complete request body, read once in the constructor. */
	private final byte[] httpRequestBodyByteArray;

	/**
	 * Wraps {@code request} and buffers its whole body, with no limit on the content reported by
	 * {@link #getContentAsByteArray()}.
	 *
	 * @param request the request to wrap
	 * @throws IOException if reading the request body fails
	 */
	public RequestWrapper(HttpServletRequest request) throws IOException {
		super(request);
		this.contentCacheLimit = null;
		this.httpRequestBodyByteArray = StreamUtils.copyToByteArray(request.getInputStream());
		this.bis = new ByteArrayInputStream(this.httpRequestBodyByteArray);
	}

	/**
	 * Wraps {@code request} and buffers its whole body. {@link #getContentAsByteArray()} reports
	 * at most {@code contentCacheLimit} bytes of it.
	 *
	 * @param request the request to wrap
	 * @param contentCacheLimit maximum number of bytes returned by {@link #getContentAsByteArray()}
	 * @throws IOException if reading the request body fails
	 */
	public RequestWrapper(HttpServletRequest request, int contentCacheLimit) throws IOException {
		// Hand the limit to the parent as well so its own (form-parameter) cache honours it.
		super(request, contentCacheLimit);
		this.contentCacheLimit = contentCacheLimit;
		this.httpRequestBodyByteArray = StreamUtils.copyToByteArray(request.getInputStream());
		this.bis = new ByteArrayInputStream(this.httpRequestBodyByteArray);
	}

	/**
	 * Returns a new stream over the buffered body. Each call starts again at the first byte,
	 * so the body can be read repeatedly.
	 *
	 * @return a fresh in-memory {@link ServletInputStream}
	 * @throws IOException never thrown here; declared so existing callers' catch blocks still compile
	 */
	@Override
	public ServletInputStream getInputStream() throws IOException {
		return new CachedBodyServletInputStream(this.httpRequestBodyByteArray);
	}

	/**
	 * Returns a copy of the buffered body, truncated to the content cache limit when one was
	 * given.
	 *
	 * @return the (possibly truncated) request body; never {@code null}
	 */
	@Override
	public byte[] getContentAsByteArray() {
		int length = this.httpRequestBodyByteArray.length;
		if (this.contentCacheLimit != null) {
			length = Math.min(length, this.contentCacheLimit);
		}
		return Arrays.copyOf(this.httpRequestBodyByteArray, length);
	}

	/**
	 * Returns the buffered request body.
	 *
	 * @return the internal body array (not a copy); callers must not modify it
	 */
	public byte[] getRequestBodyByteArray() {
		return this.httpRequestBodyByteArray;
	}

	/**
	 * Returns the shared stream over the buffered body. The same instance is returned on
	 * every call.
	 *
	 * @return the shared {@link ByteArrayInputStream}
	 */
	public ByteArrayInputStream getByteArrayInputStream() {
		return this.bis;
	}

	/**
	 * {@link ServletInputStream} over an in-memory byte array. The whole body is already
	 * available, so reads never block and {@link #isReady()} is always {@code true}.
	 */
	private static final class CachedBodyServletInputStream extends ServletInputStream {

		private final byte[] body;

		/** Index of the next byte to return. */
		private int position;

		@Nullable
		private ReadListener readListener;

		CachedBodyServletInputStream(byte[] body) {
			this.body = body;
		}

		@Override
		public boolean isFinished() {
			return this.position >= this.body.length;
		}

		@Override
		public boolean isReady() {
			// All data is in memory, so a read can always proceed without blocking.
			return true;
		}

		@Override
		public void setReadListener(ReadListener readListener) {
			this.readListener = readListener;
			try {
				if (isFinished()) {
					readListener.onAllDataRead();
				} else {
					readListener.onDataAvailable();
				}
			} catch (IOException ex) {
				readListener.onError(ex);
			}
		}

		@Override
		public int read() throws IOException {
			if (isFinished()) {
				return -1;
			}
			// Mask to 0..255: a sign-extended byte would turn 0xFF into -1, i.e. a false EOF.
			int value = this.body[this.position++] & 0xFF;
			notifyIfAllDataRead();
			return value;
		}

		@Override
		public int read(byte[] b, int off, int len) throws IOException {
			if (off < 0 || len < 0 || len > b.length - off) {
				throw new IndexOutOfBoundsException();
			}
			if (len == 0) {
				return 0;
			}
			if (isFinished()) {
				return -1;
			}
			int count = Math.min(len, this.body.length - this.position);
			System.arraycopy(this.body, this.position, b, off, count);
			this.position += count;
			notifyIfAllDataRead();
			return count;
		}

		@Override
		public int available() {
			return this.body.length - this.position;
		}

		@Override
		public void close() {
			// Closing simply marks the stream as fully consumed.
			this.position = this.body.length;
		}

		/**
		 * Invokes {@code onAllDataRead} on a registered listener once the last byte has been
		 * consumed. On failure it reports the error via {@code onError} and rethrows it.
		 */
		private void notifyIfAllDataRead() throws IOException {
			if (isFinished() && this.readListener != null) {
				try {
					this.readListener.onAllDataRead();
				} catch (IOException ex) {
					this.readListener.onError(ex);
					throw ex;
				}
			}
		}
	}
}
