/* vim:set ts=4 sts=4 sw=4 noet fenc=utf-8:

   Copyright 2009 senju@users.sourceforge.jp

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
 */

package jp.sourceforge.rabbitBTS.interceptors;

import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import jp.sourceforge.rabbitBTS.Sht;
import jp.sourceforge.rabbitBTS.controllers.IController;

import org.apache.commons.lang.RandomStringUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

/**
 * CSRF対策用インターセプター
 */
public class CSRFInterceptor extends HandlerInterceptorAdapter {
	private int expireInSecond;

	/**
	 * POSTの場合チェック処理
	 */
	@Override
	public boolean preHandle(HttpServletRequest request,
			HttpServletResponse response, Object handler) throws Exception {
		if (request.getMethod().equals("POST")
				&& handler instanceof IController) {
			final IController c = (IController) handler;
			final CsrfChecker checker = new CsrfChecker(request);
			if (checker.checkTokenValid()) {
				c.setCsrfSafe(true);
			} else {
				c.setCsrfSafe(false);
				Sht.log(this).warn("CSRF detected.");
			}
		}
		return true;
	}

	/**
	 * セッションとMAVにトークンを格納しておく
	 */
	@Override
	public void postHandle(HttpServletRequest request,
			HttpServletResponse response, Object handler, ModelAndView mav)
			throws Exception {
		// リダイレクトする場合は不要
		if (mav != null) {
			if (!StringUtils.startsWith(mav.getViewName(), "redirect:")) {
				final CsrfChecker checker = new CsrfChecker(request);
				final String token = checker.saveNewToken();
				mav.addObject("secureToken", token);
			}
		}

		if (request.getMethod().equals("POST")
				&& handler instanceof IController) {
			// きちんとCSRFチェックが行われているかチェックする
			final IController c = (IController) handler;
			if (!c.isCsrfChecked()) {
				Sht.log(this).error("CSRFチェックを行っていないPOST");
			}
			assert c.isCsrfChecked() : "CSRFチェックを行っていないPOST";
		}
	}

	/**
	 * チェック用クラス
	 */
	class CsrfChecker {
		private final HttpServletRequest req;
		private Map<String, Date> tokens;

		/**
		 * コンストラクタ
		 * 
		 * <p>
		 * セッションにトークンのリストが存在しない場合、新規に作成しセッションに保存する。
		 * 
		 * @param request
		 */
		@SuppressWarnings("unchecked")
		public CsrfChecker(HttpServletRequest request) {
			this.req = request;
			this.tokens = (Map<String, Date>) request.getSession()
					.getAttribute("tokens");
			if (this.tokens == null) {
				this.tokens = new HashMap<String, Date>();
				request.getSession().setAttribute("tokens", this.tokens);
			}
		}

		/**
		 * トークンをセッションに保存する。
		 * 
		 * @return 保存された新規トークン
		 */
		public String saveNewToken() {
			final String token = RandomStringUtils.randomAlphanumeric(128);
			this.tokens.put(token, new Date());
			return token;
		}

		/**
		 * チェックを行う
		 * 
		 * @return 正しいパラメータが送信された場合true
		 */
		public boolean checkTokenValid() {
			final String reqToken = this.req.getParameter("secureToken");
			final Date datenow = new Date();
			// トークンチェック
			boolean found = false;
			if (this.tokens.containsKey(reqToken)) {
				found = checkExpire(datenow, this.tokens.get(reqToken));
			}

			// 削除チェック
			if (this.tokens.size() > 30) {
				final List<String> removeList = new ArrayList<String>();
				for (final String token : this.tokens.keySet()) {
					final Date created = this.tokens.get(token);
					if (!this.checkExpire(datenow, created)) {
						// 期限切れの場合
						removeList.add(token);
						Sht.log(this).trace("delete token:{}", token);
					}
				}
				for (final String token : removeList) {
					this.tokens.remove(token);
				}
			}

			return found;
		}

		/**
		 * 期限切れかチェックする。
		 * 
		 * @param datenow
		 * @param created
		 * @return 期限以内の場合true
		 */
		private boolean checkExpire(Date datenow, Date created) {
			final long ageInMils = datenow.getTime() - created.getTime();
			return ageInMils / 1000 < CSRFInterceptor.this.expireInSecond;
		}
	}

	/**
	 * @param expireInSecond
	 *            the expireInSecond to set
	 */
	public void setExpireInSecond(int expireInMinute) {
		this.expireInSecond = expireInMinute;
	}

}
