package cors_test

import (
	"bytes"
	"code.justin.tv/infosec/cors"
	"net/http"
	"net/http/httptest"
	"net/http/httputil"
	"os"
	"time"
)

// CORSPolicy allows cross-origin requests from the origins "localhost" and
// "cool.com", and advertises a five-hour preflight max age (emitted as
// "Access-Control-Max-Age: 18000" in the example output below).
var CORSPolicy = cors.Policy{
	AllowedOrigins: cors.Origins("localhost", "cool.com"),
	MaxAge:         cors.MaxAge(5 * time.Hour),
}

// ContentHandler is the resource protected by the CORS policy. It writes the
// fixed body "hello" and leaves the status at the implicit 200.
func ContentHandler(w http.ResponseWriter, r *http.Request) {
	body := []byte("hello")
	// The handler has no way to report a failed write, so the result is
	// intentionally discarded.
	_, _ = w.Write(body)
}

// mainHandler is the handler served by Example; it is assembled in init.
var mainHandler http.Handler

// init builds the handler chain from the inside out. ContentHandler is first
// wrapped by cors.BlockOnOptions, so a preflight (OPTIONS) request is answered
// without reaching the content. The result is then wrapped by CORSPolicy's
// middleware, which adds the CORS response headers.
func init() {
	content := http.HandlerFunc(ContentHandler)
	guarded := cors.BlockOnOptions(content)
	mainHandler = CORSPolicy.MustMiddleware(guarded)
}

// Example serves mainHandler from a local test server. It sends a CORS
// preflight (OPTIONS) request and then a plain GET, both from the allowed
// origin "cool.com", and prints each request and response. CRLF line endings
// are converted to LF and the Date header is pinned, so the printed output is
// deterministic.
func Example() {
	s := httptest.NewServer(mainHandler)
	defer s.Close()

	// The server's own client is configured for s, and its idle connections
	// are torn down by s.Close.
	client := s.Client()

	// printNormalized writes b to stdout after converting CRLF to LF, so the
	// dumps match the expected-output comment below.
	printNormalized := func(b []byte) {
		b = bytes.Replace(b, []byte("\r\n"), []byte("\n"), -1)
		if _, err := os.Stdout.Write(b); err != nil {
			panic(err)
		}
	}

	// the rest is just demonstration

	// preflight request
	pfReq, err := http.NewRequest(http.MethodOptions, s.URL, nil)
	if err != nil {
		panic(err)
	}

	stdReq, err := http.NewRequest(http.MethodGet, s.URL, nil)
	if err != nil {
		panic(err)
	}

	for _, req := range []*http.Request{pfReq, stdReq} {
		// Host is overridden only to keep the printed dump stable; the
		// connection itself is still made to s.URL.
		req.Host = "some host"
		req.Header.Set("Origin", "cool.com")

		// Every failure panics. A silent return would only surface later as
		// a confusing output mismatch.
		reqDump, err := httputil.DumpRequest(req, true)
		if err != nil {
			panic(err)
		}
		printNormalized(reqDump)

		resp, err := client.Do(req)
		if err != nil {
			panic(err)
		}

		// Pin the Date header so the example output does not vary per run.
		resp.Header.Set("Date", "some date")

		respDump, err := httputil.DumpResponse(resp, true)
		// Close the body on every path. The Close error is ignored because
		// any real failure is reported by DumpResponse's err.
		_ = resp.Body.Close()
		if err != nil {
			panic(err)
		}
		printNormalized(respDump)
	}

	// Output:
	// OPTIONS / HTTP/1.1
	// Host: some host
	// Origin: cool.com
	//
	// HTTP/1.1 200 OK
	// Access-Control-Allow-Origin: cool.com
	// Access-Control-Max-Age: 18000
	// Content-Type: text/plain; charset=utf-8
	// Date: some date
	// Content-Length: 0
	//
	// GET / HTTP/1.1
	// Host: some host
	// Origin: cool.com
	//
	// HTTP/1.1 200 OK
	// Content-Length: 5
	// Access-Control-Allow-Origin: cool.com
	// Content-Type: text/plain; charset=utf-8
	// Date: some date
	//
	// hello
}
