package hotime import ( . "./cache" "./code" . "./common" . "./db" . "./log" "database/sql" "github.com/sirupsen/logrus" "io" "io/ioutil" "net/http" "net/url" "os" "path/filepath" "strconv" "strings" "time" ) type Application struct { *code.MakeCode MakeCodeRouter Router MethodRouter Router ContextBase Error Log *logrus.Logger WebConnectLog *logrus.Logger Port string //端口号 TLSPort string //ssl访问端口号 connectListener []func(this *Context) bool //所有的访问监听,true按原计划继续使用,false表示有监听器处理 connectDbFunc func(err ...*Error) (master, slave *sql.DB) configPath string Config Map Db HoTimeDB *HoTimeCache *http.Server http.Handler } func (that *Application) ServeHTTP(w http.ResponseWriter, req *http.Request) { that.handler(w, req) } // Run 启动实例 func (that *Application) Run(router Router) { //如果没有设置配置自动生成配置 if that.configPath == "" || len(that.Config) == 0 { that.SetConfig() } //防止手动设置缓存误伤 if that.HoTimeCache == nil { that.SetCache() } //防止手动设置session误伤 //if that.sessionShort == nil && that.sessionLong == nil { // if that.connectDbFunc == nil { // that.SetSession(CacheIns(&CacheMemory{}), nil) // } else { // that.SetSession(CacheIns(&CacheMemory{}), CacheIns(&CacheDb{Db: &that.Db, Time: that.Config.GetInt64("cacheLongTime")})) // } // //} //that.Router = router if that.Router == nil { that.Router = Router{} } for k, v := range router { that.Router[k] = v } //重新设置MethodRouter//直达路由 that.MethodRouter = MethodRouter{} modeRouterStrict := true if that.Config.GetBool("modeRouterStrict") == false { modeRouterStrict = false } if that.Router != nil { for pk, pv := range that.Router { if !modeRouterStrict { pk = strings.ToLower(pk) } if pv != nil { for ck, cv := range pv { if !modeRouterStrict { ck = strings.ToLower(ck) } if cv != nil { for mk, mv := range cv { if !modeRouterStrict { mk = strings.ToLower(mk) } that.MethodRouter["/"+pk+"/"+ck+"/"+mk] = mv } } } } } } //that.Port = port that.Port = that.Config.GetString("port") that.TLSPort = that.Config.GetString("tlsPort") if that.connectDbFunc != nil && (that.Db.DB == nil || that.Db.DB.Ping() != nil) { that.Db.SetConnect(that.connectDbFunc) } //异常处理 defer func() { if err := recover(); err != nil { //that.SetError(errors.New(fmt.Sprint(err)), LOG_FMT) that.Log.Warn(err) that.Run(router) } }() that.Server = &http.Server{} if !IsRun { IsRun = true } ch := make(chan int) if ObjToCeilInt(that.Port) != 0 { go func() { App[that.Port] = that that.Server.Handler = that //启动服务 that.Server.Addr = ":" + that.Port err := that.Server.ListenAndServe() that.Log.Error(err) ch <- 1 }() } if ObjToCeilInt(that.TLSPort) != 0 { go func() { App[that.TLSPort] = that that.Server.Handler = that //启动服务 that.Server.Addr = ":" + that.TLSPort err := that.Server.ListenAndServeTLS(that.Config.GetString("tlsCert"), that.Config.GetString("tlsKey")) that.Log.Error(err) ch <- 2 }() } if ObjToCeilInt(that.Port) == 0 && ObjToCeilInt(that.TLSPort) == 0 { that.Log.Error("没有端口启用") return } value := <-ch that.Log.Error("启动服务失败 : " + ObjToStr(value)) } // SetConnectDB 启动实例 func (that *Application) SetConnectDB(connect func(err ...*Error) (master, slave *sql.DB)) { that.connectDbFunc = connect that.Db.SetConnect(that.connectDbFunc, &that.Error) } // SetDefault 默认配置缓存和session实现 func (that *Application) SetDefault(connect func(err ...*Error) (*sql.DB, *sql.DB)) { that.SetConfig() if connect != nil { that.connectDbFunc = connect that.Db.SetConnect(that.connectDbFunc) } } // SetCache 设置配置文件路径全路径或者相对路径 func (that *Application) SetCache() { cacheIns := HoTimeCache{} cacheIns.Init(that.Config.GetMap("cache"), HoTimeDBInterface(&that.Db), &that.Error) that.HoTimeCache = &cacheIns //mode生产模式开启的时候才开启数据库缓存,防止调试出问题 if that.Config.GetInt("mode") == 0 { that.Db.HoTimeCache = &cacheIns } } // SetConfig 设置配置文件路径全路径或者相对路径 func (that *Application) SetConfig(configPath ...string) { that.Log = GetLog("", true) that.Error = Error{Logger: that.Log} if len(configPath) != 0 { that.configPath = configPath[0] } if that.configPath == "" { that.configPath = "config/config.json" } //加载配置文件 btes, err := ioutil.ReadFile(that.configPath) that.Config = DeepCopyMap(Config).(Map) if err == nil { cmap := Map{} //文件是否损坏 cmap.JsonToMap(string(btes), &that.Error) for k, v := range cmap { that.Config[k] = v //程序配置 Config[k] = v //系统配置 } } else { that.Log.Error("配置文件不存在,或者配置出错,使用缺省默认配置") } that.Log = GetLog(that.Config.GetString("logFile"), true) that.Error = Error{Logger: that.Log} if that.Config.Get("webConnectLogShow") == nil || that.Config.GetBool("webConnectLogShow") { that.WebConnectLog = GetLog(that.Config.GetString("webConnectLogFile"), false) } //文件如果损坏则不写入配置防止配置文件数据丢失 if that.Error.GetError() == nil { //var configByte bytes.Buffer //判断配置文件是否序列有变化,有则修改配置,无则不变 //fmt.Println(len(btes)) configStr := that.Config.ToJsonString() if len(btes) != 0 && configStr == string(btes) { return } //写入配置说明 //var configNoteByte bytes.Buffer configNoteStr := ConfigNote.ToJsonString() //_ = json.Indent(&configNoteByte, []byte(ConfigNote.ToJsonString()), "", "\t") _ = os.MkdirAll(filepath.Dir(that.configPath), os.ModeDir) err = ioutil.WriteFile(that.configPath, []byte(configStr), os.ModePerm) if err != nil { that.Error.SetError(err) } _ = ioutil.WriteFile(filepath.Dir(that.configPath)+"/configNote.json", []byte(configNoteStr), os.ModePerm) } } // SetConnectListener 连接判断,返回true继续传输至控制层,false则停止传输 func (that *Application) SetConnectListener(lis func(this *Context) bool) { that.connectListener = append(that.connectListener, lis) } //网络错误 //func (this *Application) session(w http.ResponseWriter, req *http.Request) { // //} //序列化链接 func (that *Application) urlSer(url string) (string, []string) { q := strings.Index(url, "?") if q == -1 { q = len(url) } o := Substr(url, 0, q) r := strings.SplitN(o, "/", -1) var s = make([]string, 0) for i := 0; i < len(r); i++ { if !strings.EqualFold("", r[i]) { s = append(s, r[i]) } } return o, s } //访问 func (that *Application) handler(w http.ResponseWriter, req *http.Request) { nowUnixTime := time.Now() _, s := that.urlSer(req.RequestURI) //获取cookie // 如果cookie存在直接将sessionId赋值为cookie.Value // 如果cookie不存在就查找传入的参数中是否有token // 如果token不存在就生成随机的sessionId // 如果token存在就判断token是否在Session中有保存 // 如果有取出token并复制给cookie // 没有保存就生成随机的session cookie, err := req.Cookie(that.Config.GetString("sessionName")) sessionId := Md5(strconv.Itoa(Rand(10))) needSetCookie := "" token := req.Header.Get("Authorization") if len(token) != 32 { token = req.FormValue("token") } //没有cookie或者cookie不等于token //有token优先token if len(token) == 32 { sessionId = token //没有token,则查阅session } else if err == nil && cookie.Value != "" { sessionId = cookie.Value //session也没有则判断是否创建cookie } else { needSetCookie = sessionId } unescapeUrl, err := url.QueryUnescape(req.RequestURI) if err != nil { unescapeUrl = req.RequestURI } //访问实例 context := Context{SessionIns: SessionIns{SessionId: sessionId, HoTimeCache: that.HoTimeCache}, Resp: w, Req: req, Application: that, RouterString: s, Config: that.Config, Db: &that.Db, HandlerStr: unescapeUrl} //header默认设置 header := w.Header() header.Set("Content-Type", "text/html; charset=utf-8") //url去掉参数并序列化 context.HandlerStr, context.RouterString = that.urlSer(context.HandlerStr) //跨域设置 that.crossDomain(&context, needSetCookie) defer func() { //是否展示日志 if that.WebConnectLog != nil { ipStr := Substr(context.Req.RemoteAddr, 0, strings.Index(context.Req.RemoteAddr, ":")) //负载均衡优化 if ipStr == "127.0.0.1" { if req.Header.Get("X-Forwarded-For") != "" { ipStr = req.Header.Get("X-Forwarded-For") } else if req.Header.Get("X-Real-IP") != "" { ipStr = req.Header.Get("X-Real-IP") } } that.WebConnectLog.Infoln(ipStr, context.Req.Method, "time cost:", ObjToFloat64(time.Now().UnixNano()-nowUnixTime.UnixNano())/1000000.00, "ms", "data length:", ObjToFloat64(context.DataSize)/1000.00, "KB", context.HandlerStr) } }() //访问拦截true继续false暂停 connectListenerLen := len(that.connectListener) if connectListenerLen != 0 { for i := 0; i < connectListenerLen; i++ { if !that.connectListener[i](&context) { context.View() return } } } //接口服务 //验证接口严格模式 modeRouterStrict := that.Config.GetBool("modeRouterStrict") tempHandlerStr := context.HandlerStr if !modeRouterStrict { tempHandlerStr = strings.ToLower(tempHandlerStr) } //执行接口 if that.MethodRouter[tempHandlerStr] != nil { that.MethodRouter[tempHandlerStr](&context) context.View() return } //url赋值 path := that.Config.GetString("tpt") + tempHandlerStr //判断是否为默认 if path[len(path)-1] == '/' { defFile := that.Config.GetSlice("defFile") for i := 0; i < len(defFile); i++ { temp := path + defFile.GetString(i) _, err := os.Stat(temp) if err == nil { path = temp break } } if path[len(path)-1] == '/' { w.WriteHeader(404) return } } if strings.Contains(path, "/.") { w.WriteHeader(404) return } //设置header delete(header, "Content-Type") if that.Config.GetInt("mode") == 0 { header.Set("Cache-Control", "public") } else { header.Set("Cache-Control", "no-cache") } t := strings.LastIndex(path, ".") if t != -1 { tt := path[t:] if MimeMaps[tt] != "" { header.Add("Content-Type", MimeMaps[tt]) } } //w.Write(data) http.ServeFile(w, req, path) } func (that *Application) crossDomain(context *Context, sessionId string) { //没有跨域设置 if context.Config.GetString("crossDomain") == "" { if sessionId != "" { http.SetCookie(context.Resp, &http.Cookie{Name: that.Config.GetString("sessionName"), Value: sessionId, Path: "/"}) } return } header := context.Resp.Header() //不跨域,则不设置 remoteHost := context.Req.Host if context.Config.GetString("port") == "80" || context.Config.GetString("port") == "443" { remoteHost = remoteHost + ":" + context.Config.GetString("port") } if context.Config.GetString("crossDomain") != "auto" { //不跨域,则不设置 if strings.Contains(context.Config.GetString("crossDomain"), remoteHost) { if sessionId != "" { http.SetCookie(context.Resp, &http.Cookie{Name: that.Config.GetString("sessionName"), Value: sessionId, Path: "/"}) } return } header.Set("Access-Control-Allow-Origin", that.Config.GetString("crossDomain")) // 后端设置,2592000单位秒,这里是30天 header.Set("Access-Control-Max-Age", "2592000") //header.Set("Access-Control-Allow-Origin", "*") header.Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,PUT,DELETE") header.Set("Access-Control-Allow-Credentials", "true") header.Set("Access-Control-Expose-Headers", "*") header.Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Access-Token") if sessionId != "" { //跨域允许需要设置cookie的允许跨域https才有效果 context.Resp.Header().Set("Set-Cookie", that.Config.GetString("sessionName")+"="+sessionId+"; Path=/; SameSite=None; Secure") } return } origin := context.Req.Header.Get("Origin") refer := context.Req.Header.Get("Referer") if (origin != "" && strings.Contains(origin, remoteHost)) || strings.Contains(refer, remoteHost) { if sessionId != "" { http.SetCookie(context.Resp, &http.Cookie{Name: that.Config.GetString("sessionName"), Value: sessionId, Path: "/"}) } return } if origin != "" { header.Set("Access-Control-Allow-Origin", origin) //return } else if refer != "" { tempInt := 0 lastInt := strings.IndexFunc(refer, func(r rune) bool { if r == '/' && tempInt > 8 { return true } tempInt++ return false }) if lastInt < 0 { lastInt = len(refer) } refer = Substr(refer, 0, lastInt) header.Set("Access-Control-Allow-Origin", refer) //header.Set("Access-Control-Allow-Origin", "*") } header.Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,PUT,DELETE") header.Set("Access-Control-Allow-Credentials", "true") header.Set("Access-Control-Expose-Headers", "*") header.Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Access-Token") if sessionId != "" { //跨域允许需要设置cookie的允许跨域https才有效果 context.Resp.Header().Set("Set-Cookie", that.Config.GetString("sessionName")+"="+sessionId+"; Path=/; SameSite=None; Secure") } } //Init 初始化application func Init(config string) Application { appIns := Application{} //手动模式, appIns.SetConfig(config) SetDB(&appIns) appIns.SetCache() appIns.MakeCode = &code.MakeCode{} codeConfig := appIns.Config.GetMap("codeConfig") appIns.MakeCodeRouter = Router{} if codeConfig != nil { for k, _ := range codeConfig { if appIns.Config.GetInt("mode") == 2 { appIns.MakeCode.Db2JSON(k, codeConfig.GetString(k), &appIns.Db, true) appIns.MakeCodeRouter[k] = Proj{} } else if appIns.Config.GetInt("mode") == 3 { appIns.MakeCode.Db2JSON(k, codeConfig.GetString(k), &appIns.Db, false) appIns.MakeCodeRouter[k] = Proj{} } else { appIns.MakeCode.Db2JSON(k, codeConfig.GetString(k), nil, false) appIns.MakeCodeRouter[k] = Proj{} } //接入动态代码层 if appIns.Router == nil { appIns.Router = Router{} } appIns.Router[k] = TptProject for k1, _ := range appIns.MakeCode.TableColumns { appIns.Router[k][k1] = appIns.Router[k]["hotimeCommon"] } setMakeCodeLintener(k, &appIns) } } return appIns } // SetDB 智能数据库设置 func SetDB(appIns *Application) { db := appIns.Config.GetMap("db") dbSqlite := db.GetMap("sqlite") dbMysql := db.GetMap("mysql") if db != nil && dbSqlite != nil { SetSqliteDB(appIns, dbSqlite) } if db != nil && dbMysql != nil { SetMysqlDB(appIns, dbMysql) } } func SetMysqlDB(appIns *Application, config Map) { appIns.Db.Type = "mysql" appIns.Db.DBName = config.GetString("name") appIns.Db.Prefix = config.GetString("prefix") appIns.SetConnectDB(func(err ...*Error) (master, slave *sql.DB) { //master数据库配置 query := config.GetString("user") + ":" + config.GetString("password") + "@tcp(" + config.GetString("host") + ":" + config.GetString("port") + ")/" + config.GetString("name") + "?charset=utf8" DB, e := sql.Open("mysql", query) if e != nil && len(err) != 0 { err[0].SetError(e) } master = DB //slave数据库配置 configSlave := config.GetMap("slave") if configSlave != nil { query := configSlave.GetString("user") + ":" + configSlave.GetString("password") + "@tcp(" + config.GetString("host") + ":" + configSlave.GetString("port") + ")/" + configSlave.GetString("name") + "?charset=utf8" DB1, e := sql.Open("mysql", query) if e != nil && len(err) != 0 { err[0].SetError(e) } slave = DB1 } return master, slave //return DB }) } func SetSqliteDB(appIns *Application, config Map) { appIns.Db.Type = "sqlite" appIns.Db.Prefix = config.GetString("prefix") appIns.SetConnectDB(func(err ...*Error) (master, slave *sql.DB) { db, e := sql.Open("sqlite3", config.GetString("path")) if e != nil && len(err) != 0 { err[0].SetError(e) } master = db return master, slave }) } func setMakeCodeLintener(name string, appIns *Application) { appIns.SetConnectListener(func(context *Context) bool { if len(context.RouterString) < 2 || appIns.MakeCodeRouter[context.RouterString[0]] == nil { return true } if len(context.RouterString) > 1 && context.RouterString[0] == name { if context.RouterString[1] == "hotime" && context.RouterString[2] == "login" { return true } if context.RouterString[1] == "hotime" && context.RouterString[2] == "logout" { return true } if context.Session(name+"_id").Data == nil { context.Display(2, "你还没有登录") return false } } //文件上传接口 if len(context.RouterString) == 1 && context.RouterString[0] == "file" && context.Req.Method == "POST" { if context.Session(name+"_id").Data == nil { context.Display(2, "你还没有登录") return false } //读取网络文件 fi, fheader, err := context.Req.FormFile("file") if err != nil { context.Display(3, err) return false } filePath := context.Config.GetString("filePath") if filePath == "" { filePath = "file/2006/01/02/" } path := time.Now().Format(filePath) e := os.MkdirAll(context.Config.GetString("tpt")+"/"+path, os.ModeDir) if e != nil { context.Display(3, e) return false } filePath = path + Md5(ObjToStr(RandX(100000, 9999999))) + fheader.Filename[strings.LastIndex(fheader.Filename, "."):] newFile, e := os.Create(context.Config.GetString("tpt") + "/" + filePath) if e != nil { context.Display(3, e) return false } _, e = io.Copy(newFile, fi) if e != nil { context.Display(3, e) return false } context.Display(0, filePath) return false } if len(context.RouterString) < 2 || len(context.RouterString) > 3 || !(context.Router[context.RouterString[0]] != nil && context.Router[context.RouterString[0]][context.RouterString[1]] != nil) { return true } //排除无效操作 if len(context.RouterString) == 2 && context.Req.Method != "GET" && context.Req.Method != "POST" { return true } //列表检索 if len(context.RouterString) == 2 && context.Req.Method == "GET" { if context.Router[context.RouterString[0]][context.RouterString[1]]["search"] == nil { return true } context.Router[context.RouterString[0]][context.RouterString[1]]["search"](context) } //新建 if len(context.RouterString) == 2 && context.Req.Method == "POST" { if context.Router[context.RouterString[0]][context.RouterString[1]]["add"] == nil { return true } context.Router[context.RouterString[0]][context.RouterString[1]]["add"](context) } if len(context.RouterString) == 3 && context.Req.Method == "POST" { return true } //查询单条 if len(context.RouterString) == 3 && context.Req.Method == "GET" { if context.Router[context.RouterString[0]][context.RouterString[1]]["info"] == nil { return true } context.Router[context.RouterString[0]][context.RouterString[1]]["info"](context) } //更新 if len(context.RouterString) == 3 && context.Req.Method == "PUT" { if context.Router[context.RouterString[0]][context.RouterString[1]]["update"] == nil { return true } context.Router[context.RouterString[0]][context.RouterString[1]]["update"](context) } //移除 if len(context.RouterString) == 3 && context.Req.Method == "DELETE" { if context.Router[context.RouterString[0]][context.RouterString[1]]["remove"] == nil { return true } context.Router[context.RouterString[0]][context.RouterString[1]]["remove"](context) } context.View() return false }) }