hotime/application.go
2021-09-12 05:35:14 +08:00

532 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package hotime
import (
. "./cache"
"./code"
. "./common"
. "./db"
. "./log"
"database/sql"
"github.com/sirupsen/logrus"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
)
// Application is the central object of the framework. It holds the loaded
// configuration, the loggers, the database handle, the cache and the routing
// tables, and it serves HTTP requests through ServeHTTP.
type Application struct {
*code.MakeCode
MethodRouter // flat "/a/b/c" -> handler table, rebuilt from Router by Run
Router
ContextBase
Error
Log *logrus.Logger // application logger, (re)created by SetConfig
WebConnectLog *logrus.Logger // when non-nil, handler writes one line per request to it
Port string // HTTP port number
TLSPort string // port number used for SSL/TLS access
connectListener []func(this *Context) bool // all registered request listeners; true lets the request continue as planned, false means a listener has handled it
connectDbFunc func(err ...*Error) (master, slave *sql.DB) // factory handed to Db.SetConnect to open the master/slave connections
configPath string // path of the JSON configuration file
Config Map // configuration values in effect for this application
Db HoTimeDB
*HoTimeCache
*http.Server
http.Handler
}
// ServeHTTP implements http.Handler by passing every request to handler.
func (that *Application) ServeHTTP(w http.ResponseWriter, req *http.Request) {
that.handler(w, req)
}
// Run starts the application instance.
//
// It makes sure configuration and cache exist, flattens router into the
// MethodRouter lookup table ("/project/controller/method" -> handler), reads
// the "port" and "tlsPort" settings, (re)connects the database when a
// connection factory is registered, and then starts one listener per enabled
// port. Run blocks until the first listener stops and logs which one it was
// (1 = plain HTTP, 2 = TLS).
//
// When both ports are enabled each listener gets its own http.Server so the
// two goroutines never write to the same Addr field; that.Server then refers
// to the plain HTTP server. With a single enabled port that.Server is the
// server used for it.
//
// A panic raised while Run itself is executing is logged and Run is invoked
// again with the same router.
func (that *Application) Run(router Router) {
	// Generate the configuration automatically when none has been set.
	if that.configPath == "" || len(that.Config) == 0 {
		that.SetConfig()
	}
	// Do not clobber a cache that was installed manually.
	if that.HoTimeCache == nil {
		that.SetCache()
	}

	that.Router = router
	// Rebuild the direct-hit method table. In non-strict mode every path
	// segment is lower-cased so lookups can be case-insensitive.
	that.MethodRouter = MethodRouter{}
	modeRouterStrict := that.Config.GetBool("modeRouterStrict")
	// Ranging over a nil map is a no-op, so no explicit nil guards are needed.
	for pk, pv := range router {
		if !modeRouterStrict {
			pk = strings.ToLower(pk)
		}
		for ck, cv := range pv {
			if !modeRouterStrict {
				ck = strings.ToLower(ck)
			}
			for mk, mv := range cv {
				if !modeRouterStrict {
					mk = strings.ToLower(mk)
				}
				that.MethodRouter["/"+pk+"/"+ck+"/"+mk] = mv
			}
		}
	}

	that.Port = that.Config.GetString("port")
	that.TLSPort = that.Config.GetString("tlsPort")

	// Reconnect when no connection exists yet or the current one is dead.
	if that.connectDbFunc != nil && (that.Db.DB == nil || that.Db.DB.Ping() != nil) {
		that.Db.SetConnect(that.connectDbFunc)
	}

	// Panic handling: log and start over.
	defer func() {
		if err := recover(); err != nil {
			that.Log.Warn(err)
			that.Run(router)
		}
	}()

	that.Server = &http.Server{}
	IsRun = true

	httpEnabled := ObjToCeilInt(that.Port) != 0
	tlsEnabled := ObjToCeilInt(that.TLSPort) != 0
	if !httpEnabled && !tlsEnabled {
		that.Log.Error("没有端口启用")
		return
	}

	that.Server.Handler = that
	// Share that.Server when only one listener runs; otherwise give the TLS
	// listener a dedicated server so Addr is never written concurrently.
	tlsServer := that.Server
	if httpEnabled && tlsEnabled {
		tlsServer = &http.Server{Handler: that}
	}

	// Buffered so a listener that fails after Run has already returned can
	// still deliver its result and exit instead of blocking forever.
	ch := make(chan int, 2)

	if httpEnabled {
		// Registered here rather than inside the goroutine to avoid
		// concurrent writes to the App map.
		App[that.Port] = that
		httpServer := that.Server
		httpServer.Addr = ":" + that.Port
		go func() {
			err := httpServer.ListenAndServe()
			that.Log.Error(err)
			ch <- 1
		}()
	}

	if tlsEnabled {
		App[that.TLSPort] = that
		tlsServer.Addr = ":" + that.TLSPort
		certFile := that.Config.GetString("tlsCert")
		keyFile := that.Config.GetString("tlsKey")
		go func() {
			err := tlsServer.ListenAndServeTLS(certFile, keyFile)
			that.Log.Error(err)
			ch <- 2
		}()
	}

	value := <-ch
	that.Log.Error("启动服务失败 : " + ObjToStr(value))
}
// SetConnectDB registers connect as the database connection factory and
// immediately passes it to Db.SetConnect together with the application's
// Error value.
func (that *Application) SetConnectDB(connect func(err ...*Error) (master, slave *sql.DB)) {
that.connectDbFunc = connect
that.Db.SetConnect(that.connectDbFunc, &that.Error)
}
// SetDefault applies the default setup: it loads the configuration through
// SetConfig and, when connect is non-nil, stores it as the connection factory
// and hands it to the database layer.
func (that *Application) SetDefault(connect func(err ...*Error) (*sql.DB, *sql.DB)) {
	that.SetConfig()

	// Nothing else to do without a connection factory.
	if connect == nil {
		return
	}

	that.connectDbFunc = connect
	that.Db.SetConnect(that.connectDbFunc)
}
// SetCache builds a HoTimeCache from the "cache" section of the configuration
// and installs it on the application.
//
// The cache is attached to the database handle only when "mode" is 0; the
// original author's note says database caching is limited to production mode
// so that it cannot interfere with debugging.
func (that *Application) SetCache() {
	cache := &HoTimeCache{}
	cache.Init(that.Config.GetMap("cache"), HoTimeDBInterface(&that.Db), &that.Error)
	that.HoTimeCache = cache

	if that.Config.GetInt("mode") != 0 {
		return
	}
	that.Db.HoTimeCache = cache
}
// SetConfig loads the configuration file and (re)creates the loggers.
//
// configPath optionally gives the full or relative path of the JSON
// configuration file; only the first value is used. When no path has been
// set, "config/config.json" is used.
//
// The effective configuration starts as a deep copy of the package-level
// Config defaults; every key found in the file overrides both the
// application's copy and the package-level Config. If the file cannot be
// read, the defaults are used and an error is logged.
//
// Afterwards the merged configuration is written back to the file (and a
// companion "configNote.json" is written next to it) unless the file content
// is already identical or the file failed to parse. A file that failed to
// parse is never overwritten, so a damaged configuration is not lost.
func (that *Application) SetConfig(configPath ...string) {
	// Bootstrap logger so problems found while loading can be reported.
	that.Log = GetLog("", true)
	that.Error = Error{Logger: that.Log}

	if len(configPath) != 0 {
		that.configPath = configPath[0]
	}
	if that.configPath == "" {
		that.configPath = "config/config.json"
	}

	// Load the configuration file.
	btes, err := ioutil.ReadFile(that.configPath)
	that.Config = DeepCopyMap(Config).(Map)

	// Remember whether parsing failed: that.Error is replaced further down,
	// which would otherwise discard the parse result before it is checked.
	corrupted := false
	if err == nil {
		cmap := Map{}
		cmap.JsonToMap(string(btes), &that.Error)
		corrupted = that.Error.GetError() != nil
		for k, v := range cmap {
			that.Config[k] = v // application configuration
			Config[k] = v      // package-level (system) configuration
		}
	} else {
		that.Log.Error("配置文件不存在,或者配置出错,使用缺省默认配置")
	}

	// Switch to the loggers described by the loaded configuration.
	that.Log = GetLog(that.Config.GetString("logFile"), true)
	that.Error = Error{Logger: that.Log}
	if that.Config.Get("webConnectLogShow") == nil || that.Config.GetBool("webConnectLogShow") {
		that.WebConnectLog = GetLog(that.Config.GetString("webConnectLogFile"), false)
	}

	// Never write over a damaged file.
	if corrupted {
		return
	}

	// Leave the file alone when its serialized form has not changed.
	configStr := that.Config.ToJsonString()
	if len(btes) != 0 && configStr == string(btes) {
		return
	}

	configNoteStr := ConfigNote.ToJsonString()
	configDir := filepath.Dir(that.configPath)

	// The directory needs real permission bits; os.ModeDir alone carries none
	// and would create a directory nobody can write into.
	if err = os.MkdirAll(configDir, 0755); err != nil {
		that.Error.SetError(err)
		return
	}
	if err = ioutil.WriteFile(that.configPath, []byte(configStr), os.ModePerm); err != nil {
		that.Error.SetError(err)
	}
	// The explanatory note file is best effort; failing to write it is not
	// treated as an error.
	_ = ioutil.WriteFile(filepath.Join(configDir, "configNote.json"), []byte(configNoteStr), os.ModePerm)
}
// SetConnectListener appends lis to the list of connection listeners.
// handler calls the listeners in registration order for every request: a
// listener returns true to let the request continue to the controller layer,
// or false to stop further processing of the request.
func (that *Application) SetConnectListener(lis func(this *Context) bool) {
that.connectListener = append(that.connectListener, lis)
}
//网络错误
//func (this *Application) session(w http.ResponseWriter, req *http.Request) {
//
//}
// urlSer normalizes a request URL for routing.
//
// It returns the part of url before the first "?" and that same part split on
// "/" with empty segments removed. The segment slice is never nil.
//
// NOTE(review): the cut position is a byte offset from strings.Index; confirm
// that Substr interprets its arguments as byte offsets too.
func (that *Application) urlSer(url string) (string, []string) {
	end := strings.Index(url, "?")
	if end < 0 {
		end = len(url)
	}
	route := Substr(url, 0, end)

	segments := []string{}
	for _, part := range strings.Split(route, "/") {
		if part == "" {
			continue
		}
		segments = append(segments, part)
	}
	return route, segments
}
// handler processes one HTTP request.
//
// Steps, in order:
//  1. Resolve the session id: use the session cookie when present; otherwise
//     (or when a 32-character "token" form value differs from the cookie)
//     use the token if it is 32 characters long, else a freshly generated id,
//     and send it back as a cookie.
//  2. Build the request Context, set the default Content-Type and the
//     cross-domain headers, and write the access log line if enabled.
//  3. Run the connection listeners; the first one returning false ends the
//     request after context.View().
//  4. Dispatch to a registered method when the path matches MethodRouter.
//  5. Otherwise serve a static file from the "tpt" directory, trying the
//     "defFile" names for paths ending in "/". Paths containing "/." are
//     answered with 404.
//
// NOTE(review): a client-supplied token is accepted as the session id without
// any check visible in this function — confirm this is intended.
// NOTE(review): url.QueryUnescape is applied to the whole RequestURI, so "+"
// inside the path is turned into a space — confirm this is intended.
func (that *Application) handler(w http.ResponseWriter, req *http.Request) {
	sessionName := that.Config.GetString("sessionName")
	cookie, err := req.Cookie(sessionName)
	sessionId := Md5(strconv.Itoa(Rand(10)))
	token := req.FormValue("token")
	if err != nil || (len(token) == 32 && cookie.Value != token) {
		if len(token) == 32 {
			sessionId = token
		}
		if that.Config.GetString("crossDomain") == "" {
			// No cross-domain configuration: a regular cookie is enough.
			http.SetCookie(w, &http.Cookie{Name: sessionName, Value: sessionId, Path: "/"})
		} else {
			// Cross-site cookies only work with SameSite=None; Secure (HTTPS).
			w.Header().Set("Set-Cookie", sessionName+"="+sessionId+"; SameSite=None; Secure")
		}
	} else {
		sessionId = cookie.Value
	}

	unescapeUrl, err := url.QueryUnescape(req.RequestURI)
	if err != nil {
		unescapeUrl = req.RequestURI
	}

	// Per-request context.
	context := Context{SessionIns: SessionIns{SessionId: sessionId, HoTimeCache: that.HoTimeCache},
		Resp: w, Req: req, Application: that, Config: that.Config, Db: &that.Db}
	// Strip the query string and split the path into its segments.
	context.HandlerStr, context.RouterString = that.urlSer(unescapeUrl)

	// Default response header.
	header := w.Header()
	header.Set("Content-Type", "text/html; charset=utf-8")

	// Cross-domain headers.
	that.crossDomain(&context)

	// Access log.
	if that.WebConnectLog != nil {
		// Cut at the last colon so bracketed IPv6 peers ("[::1]:1234") keep
		// their full address; keep the value as is when there is no port.
		remoteIP := req.RemoteAddr
		if i := strings.LastIndex(remoteIP, ":"); i >= 0 {
			remoteIP = remoteIP[:i]
		}
		that.WebConnectLog.Infoln(remoteIP, req.Method, context.HandlerStr)
	}

	// Interceptors: true continues, false stops here.
	for _, listener := range that.connectListener {
		if !listener(&context) {
			context.View()
			return
		}
	}

	// API dispatch; non-strict mode matches case-insensitively.
	tempHandlerStr := context.HandlerStr
	if !that.Config.GetBool("modeRouterStrict") {
		tempHandlerStr = strings.ToLower(tempHandlerStr)
	}
	if method := that.MethodRouter[tempHandlerStr]; method != nil {
		method(&context)
		context.View()
		return
	}

	// Static file lookup.
	path := that.Config.GetString("tpt") + tempHandlerStr
	if path == "" {
		// Nothing to look up; also keeps ServeFile from being handed an
		// empty file name.
		w.WriteHeader(http.StatusNotFound)
		return
	}

	// Directory request: try the configured default file names.
	if strings.HasSuffix(path, "/") {
		defFile := that.Config.GetSlice("defFile")
		for i := 0; i < len(defFile); i++ {
			temp := path + defFile.GetString(i)
			if _, statErr := os.Stat(temp); statErr == nil {
				path = temp
				break
			}
		}
		if strings.HasSuffix(path, "/") {
			w.WriteHeader(http.StatusNotFound)
			return
		}
	}

	// Refuse hidden files and "/.." path components.
	if strings.Contains(path, "/.") {
		w.WriteHeader(http.StatusNotFound)
		return
	}

	// Let ServeFile pick the content type for static files.
	delete(header, "Content-Type")
	if that.Config.GetInt("mode") == 0 {
		header.Set("Cache-Control", "public")
	} else {
		header.Set("Cache-Control", "no-cache")
	}
	if strings.Contains(path, ".m3u8") {
		header.Add("Content-Type", "audio/mpegurl")
	}

	http.ServeFile(w, req, path)
}
// crossDomain writes the CORS response headers for the request in context.
//
// Behaviour depends on the "crossDomain" configuration value:
//   - empty: nothing is written;
//   - any value other than "auto": that value is sent as the allowed origin
//     together with a 30-day preflight cache lifetime;
//   - "auto": the request's Origin header is echoed back; when it is absent
//     the origin (scheme://host[:port]) of the Referer header is used; when
//     neither yields an origin no Access-Control-Allow-Origin is sent.
//
// NOTE(review): in "auto" mode any requesting origin is allowed while
// credentials are enabled — confirm this is the intended policy.
func (that *Application) crossDomain(context *Context) {
	allowed := context.Config.GetString("crossDomain")
	// No cross-domain configuration.
	if allowed == "" {
		return
	}

	header := context.Resp.Header()
	header.Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,PUT,DELETE")
	header.Set("Access-Control-Allow-Credentials", "true")
	header.Set("Access-Control-Expose-Headers", "*")
	header.Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Access-Token")

	if allowed != "auto" {
		header.Set("Access-Control-Allow-Origin", allowed)
		// 2592000 seconds = 30 days.
		header.Set("Access-Control-Max-Age", "2592000")
		return
	}

	if origin := context.Req.Header.Get("Origin"); origin != "" {
		header.Set("Access-Control-Allow-Origin", origin)
		return
	}

	refer := context.Req.Header.Get("Referer")
	if refer == "" {
		return
	}
	// Derive the origin from the Referer with the URL parser rather than by
	// counting characters, so short host names, non-ASCII characters and
	// embedded user info are handled correctly.
	parsed, err := url.Parse(refer)
	if err != nil || parsed.Scheme == "" || parsed.Host == "" {
		return
	}
	header.Set("Access-Control-Allow-Origin", parsed.Scheme+"://"+parsed.Host)
}
// Init creates and initializes an Application in "manual" mode.
//
// It loads the configuration file at config, configures the database from the
// "db" section, sets up the cache, and then calls MakeCode.Db2JSON once per
// entry of the "codeConfig" section. The database handle is passed to Db2JSON
// only when "mode" is 2; otherwise nil is passed.
//
// NOTE(review): the Application is returned by value, while setup steps such
// as SetCache and SetConnectDB take the address of fields of the local
// instance — confirm that copying the value afterwards is intended.
func Init(config string) Application {
	app := Application{}

	// Manual mode.
	app.SetConfig(config)
	SetDB(&app)
	app.SetCache()
	app.MakeCode = &code.MakeCode{}

	// Ranging over a nil map is a no-op, so a missing section needs no guard.
	codeConfig := app.Config.GetMap("codeConfig")
	for name := range codeConfig {
		if app.Config.GetInt("mode") != 2 {
			app.MakeCode.Db2JSON(name, codeConfig.GetString(name), nil)
			continue
		}
		app.MakeCode.Db2JSON(name, codeConfig.GetString(name), &app.Db)
	}

	return app
}
// SetDB configures the database from the "db" section of the configuration.
// A "sqlite" subsection is applied through SetSqliteDB and a "mysql"
// subsection through SetMysqlDB; when both are present MySQL is applied last.
func SetDB(appIns *Application) {
	dbConf := appIns.Config.GetMap("db")
	sqliteConf, mysqlConf := dbConf.GetMap("sqlite"), dbConf.GetMap("mysql")

	if dbConf == nil {
		return
	}
	if sqliteConf != nil {
		SetSqliteDB(appIns, sqliteConf)
	}
	if mysqlConf != nil {
		SetMysqlDB(appIns, mysqlConf)
	}
}
// SetMysqlDB configures appIns to use MySQL.
//
// config supplies "user", "password", "host", "port", "name" and "prefix" for
// the master connection. An optional "slave" map supplies the same connection
// keys for a read replica; when the slave map has no "host", the master host
// is used, so configurations that only override the port keep working.
//
// Errors from sql.Open are reported through the first *Error passed to the
// connection factory, if any.
func SetMysqlDB(appIns *Application, config Map) {
	appIns.Db.Type = "mysql"
	appIns.Db.DBName = config.GetString("name")
	appIns.Db.Prefix = config.GetString("prefix")

	appIns.SetConnectDB(func(err ...*Error) (master, slave *sql.DB) {
		// open builds the DSN and opens a handle, reporting failures to the
		// caller-supplied error holder.
		open := func(user, password, host, port, name string) *sql.DB {
			dsn := user + ":" + password +
				"@tcp(" + host + ":" + port + ")/" + name + "?charset=utf8"
			conn, e := sql.Open("mysql", dsn)
			if e != nil && len(err) != 0 {
				err[0].SetError(e)
			}
			return conn
		}

		// Master database.
		master = open(config.GetString("user"), config.GetString("password"),
			config.GetString("host"), config.GetString("port"), config.GetString("name"))

		// Slave database.
		if configSlave := config.GetMap("slave"); configSlave != nil {
			// Use the slave's own host; fall back to the master host when the
			// slave section does not define one.
			slaveHost := configSlave.GetString("host")
			if slaveHost == "" {
				slaveHost = config.GetString("host")
			}
			slave = open(configSlave.GetString("user"), configSlave.GetString("password"),
				slaveHost, configSlave.GetString("port"), configSlave.GetString("name"))
		}

		return master, slave
	})
}
// SetSqliteDB configures appIns to use SQLite. config supplies "prefix" and
// the database file "path". Only a master connection is opened; the slave is
// always nil. An error from sql.Open is reported through the first *Error
// passed to the connection factory, if any.
func SetSqliteDB(appIns *Application, config Map) {
	appIns.Db.Type = "sqlite"
	appIns.Db.Prefix = config.GetString("prefix")

	connect := func(err ...*Error) (master, slave *sql.DB) {
		conn, openErr := sql.Open("sqlite3", config.GetString("path"))
		if openErr != nil && len(err) > 0 {
			err[0].SetError(openErr)
		}
		return conn, nil
	}
	appIns.SetConnectDB(connect)
}