From 9c00ac6ba1c480a3bf17b5a6cd39e55cc778a742 Mon Sep 17 00:00:00 2001 From: hoteas Date: Tue, 28 Dec 2021 09:26:26 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E8=87=AA=E9=80=82=E5=BA=94?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- application.go | 55 ++++++++++++++++++++++++++++-------------- example/tpt/index.html | 2 +- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/application.go b/application.go index 52e2882..2577ae0 100644 --- a/application.go +++ b/application.go @@ -302,6 +302,7 @@ func (that *Application) handler(w http.ResponseWriter, req *http.Request) { // 没有保存就生成随机的session cookie, err := req.Cookie(that.Config.GetString("sessionName")) sessionId := Md5(strconv.Itoa(Rand(10))) + needSetCookie := "" token := req.Header.Get("Authorization") if len(token) != 32 { @@ -314,16 +315,9 @@ func (that *Application) handler(w http.ResponseWriter, req *http.Request) { //没有token,则查阅session } else if err == nil && cookie.Value != "" { sessionId = cookie.Value - //session也没有则判断是否创建cookie } else { - //跨域不再通过cookie校验 - if that.Config.GetString("crossDomain") == "" { - http.SetCookie(w, &http.Cookie{Name: that.Config.GetString("sessionName"), Value: sessionId, Path: "/"}) - } else { - //跨域允许需要设置cookie的允许跨域https才有效果 - w.Header().Set("Set-Cookie", that.Config.GetString("sessionName")+"="+sessionId+"; Path=/; SameSite=None; Secure") - } + needSetCookie = sessionId } unescapeUrl, err := url.QueryUnescape(req.RequestURI) @@ -342,7 +336,7 @@ func (that *Application) handler(w http.ResponseWriter, req *http.Request) { context.HandlerStr, context.RouterString = that.urlSer(context.HandlerStr) //跨域设置 - that.crossDomain(&context) + that.crossDomain(&context, needSetCookie) defer func() { //是否展示日志 @@ -356,6 +350,7 @@ func (that *Application) handler(w http.ResponseWriter, req *http.Request) { ipStr = req.Header.Get("X-Real-IP") } } + that.WebConnectLog.Infoln(ipStr, context.Req.Method, "time cost:", ObjToFloat64(time.Now().UnixNano()-nowUnixTime.UnixNano())/1000000.00, "ms", "data length:", ObjToFloat64(context.DataSize)/1000.00, "KB", context.HandlerStr) @@ -434,9 +429,14 @@ func (that *Application) handler(w http.ResponseWriter, req *http.Request) { } -func (that *Application) crossDomain(context *Context) { +func (that *Application) crossDomain(context *Context, sessionId string) { + //没有跨域设置 if context.Config.GetString("crossDomain") == "" { + if sessionId != "" { + http.SetCookie(context.Resp, &http.Cookie{Name: that.Config.GetString("sessionName"), Value: sessionId, Path: "/"}) + } + return } @@ -450,6 +450,10 @@ func (that *Application) crossDomain(context *Context) { if context.Config.GetString("crossDomain") != "auto" { //不跨域,则不设置 if strings.Contains(context.Config.GetString("crossDomain"), remoteHost) { + + if sessionId != "" { + http.SetCookie(context.Resp, &http.Cookie{Name: that.Config.GetString("sessionName"), Value: sessionId, Path: "/"}) + } return } header.Set("Access-Control-Allow-Origin", that.Config.GetString("crossDomain")) @@ -462,22 +466,29 @@ func (that *Application) crossDomain(context *Context) { header.Set("Access-Control-Expose-Headers", "*") header.Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Access-Token") + if sessionId != "" { + //跨域允许需要设置cookie的允许跨域https才有效果 + context.Resp.Header().Set("Set-Cookie", that.Config.GetString("sessionName")+"="+sessionId+"; Path=/; SameSite=None; Secure") + } return } origin := context.Req.Header.Get("Origin") refer := context.Req.Header.Get("Referer") - if strings.Contains(origin, remoteHost) || strings.Contains(refer, remoteHost) { + if (origin != "" && strings.Contains(origin, remoteHost)) || strings.Contains(refer, remoteHost) { + + if sessionId != "" { + http.SetCookie(context.Resp, &http.Cookie{Name: that.Config.GetString("sessionName"), Value: sessionId, Path: "/"}) + } + return } if origin != "" { header.Set("Access-Control-Allow-Origin", origin) //return - } - - if refer != "" { + } else if refer != "" { tempInt := 0 lastInt := strings.IndexFunc(refer, func(r rune) bool { if r == '/' && tempInt > 8 { @@ -493,11 +504,19 @@ func (that *Application) crossDomain(context *Context) { refer = Substr(refer, 0, lastInt) header.Set("Access-Control-Allow-Origin", refer) //header.Set("Access-Control-Allow-Origin", "*") - header.Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,PUT,DELETE") - header.Set("Access-Control-Allow-Credentials", "true") - header.Set("Access-Control-Expose-Headers", "*") - header.Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Access-Token") + } + + header.Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,PUT,DELETE") + header.Set("Access-Control-Allow-Credentials", "true") + header.Set("Access-Control-Expose-Headers", "*") + header.Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Access-Token") + + if sessionId != "" { + //跨域允许需要设置cookie的允许跨域https才有效果 + context.Resp.Header().Set("Set-Cookie", that.Config.GetString("sessionName")+"="+sessionId+"; Path=/; SameSite=None; Secure") + } + } //Init 初始化application diff --git a/example/tpt/index.html b/example/tpt/index.html index bec9c55..56e91ba 100644 --- a/example/tpt/index.html +++ b/example/tpt/index.html @@ -1,3 +1,3 @@
\ No newline at end of file + }
\ No newline at end of file