// Source: iedc-go/application.go
// Last modified: 2021-05-24 06:14:58 +08:00 (457 lines, 11 KiB, Go).
// The web viewer this was copied from warned that the file contains
// ambiguous Unicode characters that might be confused with other characters.

package hotime
import (
"bytes"
"database/sql"
"encoding/json"
"github.com/sirupsen/logrus"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
)
// Application is one HoTime web application: routing tables, configuration,
// database handle, cache/session stores and the http.Server started by Run.
// Requests are served through ServeHTTP.
//
// NOTE(review): Log is a logrus.Logger held by value; logrus.Logger carries
// internal lock state, so an Application should not be copied once in use
// (Init currently returns one by value) — confirm with `go vet` copylocks.
type Application struct {
	// MethodRouter maps "/project/controller/method" to its handler; Run
	// rebuilds it from Router (keys lower-cased unless modeRouterStrict).
	MethodRouter
	// Router is the three-level router passed to Run.
	Router
	// contextBase presumably supplies the Error field used by SetConfig — confirm.
	contextBase
	Log             logrus.Logger
	Port            string                     // HTTP port; Run reads it from config "port"
	TLSPort         string                     // HTTPS port; Run reads it from config "tlsPort"
	connectListener []func(this *Context) bool // request listeners run in order; true continues normally, false means a listener handled the request
	connectDbFunc   func(err ...*Error) (master, slave *sql.DB) // database connect function handed to Db.SetConnect
	configPath      string                     // config file path; SetConfig defaults it to "config/config.json"
	Config          Map                        // merged configuration (package defaults + config file)
	Db              HoTimeDB
	Server          *http.Server // server created by Run
	CacheIns                     // application cache; Run defaults it to a CacheMemory
	sessionLong     CacheIns     // long-lived session store (Run's default is CacheDb when a DB connect func is set)
	sessionShort    CacheIns     // short-lived session store (Run's default is CacheMemory)
	// NOTE(review): the embedded http.Handler is not used in this file and
	// ServeHTTP is defined explicitly — confirm whether it is still needed.
	http.Handler
}
// ServeHTTP implements http.Handler by handing every request to the
// application's internal dispatcher.
func (this *Application) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	dispatch := this.handler
	dispatch(w, req)
}
// Run starts the application instance.
//
// It loads the configuration if none was set and installs default cache and
// session backends when none were configured. It rebuilds MethodRouter from
// router and (re)connects the database when a connect function is registered.
// It then starts an HTTP listener on config "port" and/or an HTTPS listener on
// config "tlsPort" (cert/key from "tlsCert"/"tlsKey").
//
// Run blocks until one listener stops and then logs the failure. If neither
// port is configured it logs that and returns immediately. A panic raised
// while Run itself executes is logged and Run is invoked again.
func (this *Application) Run(router Router) {
	// Generate/load the configuration automatically if it was not set up.
	if this.configPath == "" || len(this.Config) == 0 {
		this.SetConfig()
	}
	// Only install defaults, so manually configured backends are kept.
	if this.CacheIns == nil {
		this.SetCache(CacheIns(&CacheMemory{}))
	}
	if this.sessionShort == nil && this.sessionLong == nil {
		if this.connectDbFunc == nil {
			this.SetSession(CacheIns(&CacheMemory{}), nil)
		} else {
			this.SetSession(CacheIns(&CacheMemory{}), CacheIns(&CacheDb{Db: &this.Db, Time: this.Config.GetInt64("cacheLongTime")}))
		}
	}

	this.Router = router

	// Strict (case-sensitive) routing unless "modeRouterStrict" is explicitly
	// false. A missing or non-bool value previously panicked here.
	modeRouterStrict := true
	if strict, ok := this.Config.Get("modeRouterStrict").(bool); ok {
		modeRouterStrict = strict
	}
	normalize := func(s string) string {
		if modeRouterStrict {
			return s
		}
		return strings.ToLower(s)
	}
	// Rebuild the flattened "/project/controller/method" table. Ranging over a
	// nil map is a no-op, so no explicit nil checks are needed.
	this.MethodRouter = MethodRouter{}
	for pk, pv := range router {
		for ck, cv := range pv {
			for mk, mv := range cv {
				this.MethodRouter["/"+normalize(pk)+"/"+normalize(ck)+"/"+normalize(mk)] = mv
			}
		}
	}

	this.Port = this.Config.GetString("port")
	this.TLSPort = this.Config.GetString("tlsPort")
	if this.connectDbFunc != nil && (this.Db.DB == nil || this.Db.DB.Ping() != nil) {
		this.Db.SetConnect(this.connectDbFunc)
	}

	// Recover from start-up panics: log them and start again.
	defer func() {
		if err := recover(); err != nil {
			logFmt(err, 2, LOG_ERROR)
			this.Run(router)
		}
	}()

	if !IsRun {
		IsRun = true
	}

	// Each listener gets its own http.Server. Previously both goroutines
	// shared one Server and raced on its Addr field, so the plain listener
	// could end up bound to the TLS port.
	var httpServer, tlsServer *http.Server
	if ObjToCeilInt(this.Port) != 0 {
		httpServer = &http.Server{Addr: ":" + this.Port, Handler: this}
	}
	if ObjToCeilInt(this.TLSPort) != 0 {
		tlsServer = &http.Server{Addr: ":" + this.TLSPort, Handler: this}
	}
	switch {
	case httpServer != nil:
		this.Server = httpServer
	case tlsServer != nil:
		this.Server = tlsServer
	default:
		this.Server = &http.Server{}
		logFmt("没有端口启用", 2, LOG_INFO)
		return
	}

	// Buffered for both listeners, so a listener that stops after Run has
	// already returned does not block forever on the send.
	ch := make(chan int, 2)
	// App is registered here, on this goroutine, instead of inside the
	// listener goroutines where the two writes raced on the shared map.
	if httpServer != nil {
		App[this.Port] = this
		go func() {
			err := httpServer.ListenAndServe()
			logFmt(err, 2)
			ch <- 1
		}()
	}
	if tlsServer != nil {
		App[this.TLSPort] = this
		go func() {
			err := tlsServer.ListenAndServeTLS(this.Config.GetString("tlsCert"), this.Config.GetString("tlsKey"))
			logFmt(err, 2)
			ch <- 2
		}()
	}
	value := <-ch
	logFmt("启动服务失败 : "+ObjToStr(value), 2, LOG_ERROR)
}
// SetConnectDB registers the database connect function and immediately
// hands it to the application's HoTimeDB handle.
func (this *Application) SetConnectDB(connect func(err ...*Error) (master, slave *sql.DB)) {
	this.connectDbFunc = connect
	this.Db.SetConnect(connect)
}
// SetSession installs the session stores: short is the short-lived store and
// Long the long-lived one. Run passes nil for Long when no database connect
// function is registered.
func (this *Application) SetSession(short CacheIns, Long CacheIns) {
	this.sessionShort, this.sessionLong = short, Long
}
// SetDefault loads the default configuration and, when connect is non-nil,
// registers it as the database connect function.
func (this *Application) SetDefault(connect func(err ...*Error) (*sql.DB, *sql.DB)) {
	this.SetConfig()
	if connect == nil {
		return
	}
	this.connectDbFunc = connect
	this.Db.SetConnect(this.connectDbFunc)
}
// SetCache replaces the application-wide cache implementation. Run only
// installs its default in-memory cache when none has been set.
func (this *Application) SetCache(cache CacheIns) {
	app := this
	app.CacheIns = cache
}
// SetConfig loads the JSON configuration file and merges it into the
// application configuration.
//
// configPath optionally sets the file location (absolute or relative). When no
// path has ever been set, "config/config.json" is used. this.Config starts as
// a deep copy of the package-level Config. Every key found in the file then
// overrides it in both this.Config and the package-level Config.
//
// If this.Error holds no error afterwards, the merged configuration is written
// back (tab-indented) whenever it differs from the file on disk. The contents
// of ConfigNote are also written to confignote.json in the same directory.
// A corrupt config file is therefore never overwritten.
func (this *Application) SetConfig(configPath ...string) {
	if len(configPath) != 0 {
		this.configPath = configPath[0]
	}
	if this.configPath == "" {
		this.configPath = "config/config.json"
	}

	// Load the config file on top of a copy of the package defaults.
	btes, err := ioutil.ReadFile(this.configPath)
	this.Config = DeepCopyMap(Config).(Map)
	if err == nil {
		cmap := Map{}
		// Presumably records a parse error in this.Error when the file is
		// corrupt (checked below) — confirm against JsonToMap.
		cmap.JsonToMap(string(btes), &this.Error)
		for k, v := range cmap {
			this.Config[k] = v // application config
			Config[k] = v      // package-level (system) config
		}
	} else {
		logFmt("配置文件不存在,或者配置出错,使用缺省默认配置", 2)
	}

	// Do not rewrite a possibly corrupt file, so its data is not lost.
	if this.Error.GetError() != nil {
		return
	}

	var configByte bytes.Buffer
	if err := json.Indent(&configByte, []byte(this.Config.ToJsonString()), "", "\t"); err != nil {
		// Never replace the file with an empty or partial buffer.
		this.Error.SetError(err)
		return
	}
	// Leave the file alone when its serialized content is unchanged.
	if len(btes) != 0 && configByte.String() == string(btes) {
		return
	}

	// os.ModeDir / os.ModeAppend (used previously) carry no permission bits,
	// which created the directory and file with mode 0000.
	dir := filepath.Dir(this.configPath)
	if err := os.MkdirAll(dir, 0755); err != nil {
		this.Error.SetError(err)
		return
	}
	if err := ioutil.WriteFile(this.configPath, configByte.Bytes(), 0644); err != nil {
		this.Error.SetError(err)
	}

	// Write the config documentation alongside; best effort only, since it is
	// purely informational.
	var configNoteByte bytes.Buffer
	if err := json.Indent(&configNoteByte, []byte(ConfigNote.ToJsonString()), "", "\t"); err == nil {
		_ = ioutil.WriteFile(filepath.Join(dir, "confignote.json"), configNoteByte.Bytes(), 0644)
	}
}
// SetConnectListener appends a listener that handler runs, in registration
// order, on every request before routing. Returning true passes the request
// on to the controller layer. Returning false stops processing, and the
// response is rendered with the context's View.
func (this *Application) SetConnectListener(lis func(this *Context) bool) {
	listeners := this.connectListener
	this.connectListener = append(listeners, lis)
}
//网络错误
//func (this *Application) session(w http.ResponseWriter, req *http.Request) {
//
//}
// urlSer strips the query string (everything from the first '?') from url and
// splits the remaining path on '/', dropping empty segments.
// For example, "/a//b/?x=1" yields ("/a//b/", []string{"a", "b"}).
// The returned slice is never nil.
func (this *Application) urlSer(url string) (string, []string) {
	path := url
	// strings.Index returns a byte offset, so slice by bytes. This stays
	// correct for non-ASCII paths regardless of how Substr counts.
	if q := strings.Index(url, "?"); q >= 0 {
		path = url[:q]
	}
	segments := make([]string, 0)
	for _, part := range strings.Split(path, "/") {
		if part != "" {
			segments = append(segments, part)
		}
	}
	return path, segments
}
// handler serves one request:
//  1. resolves the session id: the session cookie, or a 32-char "token" form
//     value that differs from the cookie, or a freshly generated id. It sets
//     the cookie whenever the id did not come from it;
//  2. builds the per-request Context with the unescaped path (query removed);
//  3. applies CORS headers and, if "connectLogShow" is non-zero, logs access;
//  4. runs the connect listeners, any of which may stop processing;
//  5. calls the MethodRouter handler on an exact path match (lower-cased
//     unless modeRouterStrict);
//  6. otherwise serves a static file under config "tpt". Directory paths try
//     each "defFile" name. Paths containing "/." get a 404.
func (this *Application) handler(w http.ResponseWriter, req *http.Request) {
	sessionName := this.Config.GetString("sessionName")
	cookie, err := req.Cookie(sessionName)
	token := req.FormValue("token")
	var sessionId string
	if err != nil || (len(token) == 32 && cookie.Value != token) {
		if len(token) == 32 {
			sessionId = token
		} else {
			// NOTE(review): an MD5 of a small random number is likely guessable;
			// consider 16 bytes from crypto/rand, hex-encoded (also 32 chars).
			sessionId = Md5(strconv.Itoa(Rand(10)))
		}
		http.SetCookie(w, &http.Cookie{Name: sessionName, Value: sessionId, Path: "/"})
	} else {
		sessionId = cookie.Value
	}

	// Strip the query from the raw URI *before* unescaping, so an escaped
	// "%3F" inside the path is not taken as the query separator. Use path
	// unescaping (not query unescaping) so a literal '+' in a file name is
	// not turned into a space.
	rawPath, _ := this.urlSer(req.RequestURI)
	handlerStr, err := url.PathUnescape(rawPath)
	if err != nil {
		handlerStr = rawPath
	}
	routerString := strings.FieldsFunc(handlerStr, func(r rune) bool { return r == '/' })

	context := Context{
		SessionIns: SessionIns{
			SessionId:  sessionId,
			LongCache:  this.sessionLong,
			ShortCache: this.sessionShort,
		},
		CacheIns:     this.CacheIns,
		Resp:         w,
		Req:          req,
		Application:  this,
		RouterString: routerString,
		Config:       this.Config,
		Db:           &this.Db,
		HandlerStr:   handlerStr,
	}
	header := w.Header()
	header.Set("Content-Type", "text/html; charset=utf-8")
	this.crossDomain(&context)

	if this.Config.GetInt("connectLogShow") != 0 {
		// RemoteAddr is "host:port". Cut at the last ':' so IPv6 hosts keep
		// their colons and a missing port cannot produce a negative index.
		remote := req.RemoteAddr
		if i := strings.LastIndex(remote, ":"); i >= 0 {
			remote = remote[:i]
		}
		logFmt(remote+" "+context.HandlerStr, 0, LOG_INFO)
	}

	// Listeners: true continues, false stops here.
	for _, listener := range this.connectListener {
		if !listener(&context) {
			context.View()
			return
		}
	}

	// Strict routing unless "modeRouterStrict" is explicitly false; a missing
	// or non-bool value previously panicked.
	modeRouterStrict := true
	if strict, ok := this.Config.Get("modeRouterStrict").(bool); ok {
		modeRouterStrict = strict
	}
	// Listeners receive &context and may have changed HandlerStr.
	tempHandlerStr := context.HandlerStr
	if !modeRouterStrict {
		tempHandlerStr = strings.ToLower(tempHandlerStr)
	}
	if method := this.MethodRouter[tempHandlerStr]; method != nil {
		method(&context)
		context.View()
		return
	}

	// Static file fallback. HasSuffix is safe on an empty path, where the
	// previous path[len(path)-1] indexing would panic.
	path := this.Config.GetString("tpt") + tempHandlerStr
	if strings.HasSuffix(path, "/") {
		defFile := this.Config.GetSlice("defFile")
		for i := 0; i < len(defFile); i++ {
			candidate := path + defFile.GetString(i)
			if _, err := os.Stat(candidate); err == nil {
				path = candidate
				break
			}
		}
		if strings.HasSuffix(path, "/") {
			w.WriteHeader(http.StatusNotFound)
			return
		}
	}
	// Refuse hidden files and ".." segments.
	if strings.Contains(path, "/.") {
		w.WriteHeader(http.StatusNotFound)
		return
	}

	// Let ServeFile choose the content type from the file.
	delete(header, "Content-Type")
	if this.Config.GetInt("debug") != 1 {
		header.Set("Cache-Control", "public")
	}
	if strings.Contains(path, ".m3u8") {
		header.Add("Content-Type", "audio/mpegurl")
	}
	http.ServeFile(w, req, path)
}
// crossDomain writes CORS response headers according to config "crossDomain":
//   - "" (empty): nothing is written;
//   - "auto": Access-Control-Allow-Origin echoes the request's Origin header.
//     Without an Origin it uses the scheme://host of the Referer;
//   - any other value: used verbatim as Access-Control-Allow-Origin.
//
// When enabled, all methods and exposed headers are allowed, and credentials
// are allowed too.
func (this *Application) crossDomain(context *Context) {
	allowed := context.Config.GetString("crossDomain")
	if allowed == "" {
		return
	}
	header := context.Resp.Header()
	header.Set("Access-Control-Allow-Methods", "*")
	header.Set("Access-Control-Allow-Credentials", "true")
	header.Set("Access-Control-Expose-Headers", "*")
	header.Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Access-Token")
	if allowed != "auto" {
		header.Set("Access-Control-Allow-Origin", allowed)
		return
	}

	// The allowed origin now depends on the request. Shared caches must key
	// on Origin, especially since static files are sent Cache-Control: public.
	header.Add("Vary", "Origin")
	// NOTE(review): reflecting any Origin together with
	// Access-Control-Allow-Credentials: true lets every site make credentialed
	// requests; consider an allow-list.
	if origin := context.Req.Header.Get("Origin"); origin != "" {
		header.Set("Access-Control-Allow-Origin", origin)
		return
	}
	// Fall back to the Referer's origin (scheme://host[:port]). url.Parse
	// replaces the previous rune-counting scan, which kept a trailing '/' for
	// short hosts such as "http://a/" and leaked any userinfo.
	if refer := context.Req.Header.Get("Referer"); refer != "" {
		if u, err := url.Parse(refer); err == nil && u.Scheme != "" && u.Host != "" {
			header.Set("Access-Control-Allow-Origin", u.Scheme+"://"+u.Host)
		}
	}
}
// Init creates an Application in manual mode. It loads the configuration from
// the file path config and then passes the instance to SetDB.
//
// NOTE(review): the Application is returned by value while its Log field is a
// logrus.Logger holding lock state — avoid copying it further once in use.
func Init(config string) Application {
	var app Application
	app.SetConfig(config)
	SetDB(&app)
	return app
}