package hotime

import (
	"bytes"
	"database/sql"
	"encoding/json"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
)

// Application is one hotime web application instance: its configuration,
// routing tables, database handle, cache/session backends and the
// *http.Server it runs.
type Application struct {
	MethodRouter
	Router
	contextBase
	Port            string                     // HTTP listen port
	TLSPort         string                     // HTTPS (TLS) listen port
	connectListener []func(this *Context) bool // request listeners; true = continue as planned, false = a listener handled the request
	connectDbFunc   func(err ...*Error) *sql.DB
	configPath      string
	Config          Map
	Db              HoTimeDB
	Server          *http.Server
	CacheIns
	sessionLong  CacheIns
	sessionShort CacheIns
	http.Handler
}

// ServeHTTP implements http.Handler by delegating every request to handler.
func (this *Application) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	this.handler(w, req)
}

// Run starts the application with the given router and blocks until the
// server stops.
//
// Missing configuration, cache and session backends are filled in with
// defaults, the flattened "/project/controller/method" route table is
// rebuilt, and a server is started on the configured "port" or, when that is
// unset/zero, a TLS server on "tlsPort". A panic raised after the server
// setup begins is logged and Run restarts itself with the same router.
func (this *Application) Run(router Router) {
	// Generate a configuration when none has been loaded yet.
	if this.configPath == "" || len(this.Config) == 0 {
		this.SetConfig()
	}
	// Only install a default cache when the caller has not set one.
	// NOTE(review): the "cacheShortTime" setting is not applied to this
	// default cache; confirm whether that is intended.
	if this.CacheIns == nil {
		this.SetCache(CacheIns(&CacheMemory{}))
	}
	// Only install default session backends when the caller has not set any.
	if this.sessionShort == nil && this.sessionLong == nil {
		if this.connectDbFunc == nil {
			this.SetSession(CacheIns(&CacheMemory{}), nil)
		} else {
			this.SetSession(CacheIns(&CacheMemory{}), CacheIns(&CacheDb{Db: &this.Db, Time: this.Config.GetInt64("cacheLongTime")}))
		}
	}

	this.Router = router

	// Rebuild the direct-lookup route table. Strict mode is the default; an
	// absent or non-bool "modeRouterStrict" no longer panics.
	this.MethodRouter = MethodRouter{}
	modeRouterStrict := true
	if strict, ok := this.Config.Get("modeRouterStrict").(bool); ok && !strict {
		modeRouterStrict = false
	}
	// In non-strict mode keys are stored lower-cased; handler lower-cases the
	// request path the same way before looking it up.
	normalize := func(s string) string {
		if modeRouterStrict {
			return s
		}
		return strings.ToLower(s)
	}
	// Ranging over a nil map is a no-op, so nil levels are simply skipped.
	for pk, pv := range router {
		for ck, cv := range pv {
			for mk, mv := range cv {
				this.MethodRouter["/"+normalize(pk)+"/"+normalize(ck)+"/"+normalize(mk)] = mv
			}
		}
	}

	this.Port = this.Config.GetString("port")
	this.TLSPort = this.Config.GetString("tlsPort")

	// (Re)connect when there is no DB handle yet or it no longer answers a ping.
	if this.connectDbFunc != nil && (this.Db.DB == nil || this.Db.DB.Ping() != nil) {
		this.Db.SetConnect(this.connectDbFunc)
	}

	// Log a panic from the serving phase and restart the application.
	defer func() {
		if err := recover(); err != nil {
			logFmt(err, 2, LOG_ERROR)
			this.Run(router)
		}
	}()

	this.Server = &http.Server{}
	IsRun = true

	// The serving goroutine reports on ch once its Listen* call returns.
	ch := make(chan int)
	switch {
	case ObjToCeilInt(this.Port) != 0:
		go func() {
			App[this.Port] = this
			this.Server.Handler = this
			this.Server.Addr = ":" + this.Port
			err := this.Server.ListenAndServe()
			logFmt(err, 2)
			ch <- 1
		}()
	case ObjToCeilInt(this.TLSPort) != 0:
		go func() {
			App[this.TLSPort] = this
			this.Server.Handler = this
			this.Server.Addr = ":" + this.TLSPort
			err := this.Server.ListenAndServeTLS(this.Config.GetString("tlsCert"), this.Config.GetString("tlsKey"))
			logFmt(err, 2)
			ch <- 2
		}()
	default:
		logFmt("没有端口启用", 2, LOG_INFO)
		return
	}

	value := <-ch
	logFmt("启动服务失败 : "+ObjToStr(value), 2, LOG_ERROR)
}

// SetConnectDB registers the database connection factory, connects this.Db
// through it, and applies the "dbCached" and "dbType" configuration values.
func (this *Application) SetConnectDB(connect func(err ...*Error) *sql.DB) {
	this.connectDbFunc = connect
	this.Db.SetConnect(this.connectDbFunc)
	this.Db.DBCached = this.Config.GetCeilInt("dbCached") != 0
	this.Db.Type = this.Config.GetString("dbType")
}

// SetSession sets the short-lived and long-lived session cache backends.
func (this *Application) SetSession(short CacheIns, Long CacheIns) {
	this.sessionLong = Long
	this.sessionShort = short
}

// SetDefault loads the default configuration and, when connect is non-nil,
// registers it as the database connection factory and connects this.Db.
func (this *Application) SetDefault(connect func(err ...*Error) *sql.DB) {
	this.SetConfig()
	if connect != nil {
		this.connectDbFunc = connect
		this.Db.SetConnect(this.connectDbFunc)
	}
}

// SetCache sets the application-wide cache backend.
func (this *Application) SetCache(cache CacheIns) {
	this.CacheIns = cache
}

// SetConfig loads the configuration file (absolute or relative path).
//
// The path is configPath[0] when given, else the previously set path, else
// "config/config.json". File values are layered over a deep copy of the
// package defaults and also copied into the package-level Config. Unless the
// file failed to parse, the merged configuration is written back (only when
// it differs from the file), together with "confignote.json" in the same
// directory.
func (this *Application) SetConfig(configPath ...string) {
	if len(configPath) != 0 {
		this.configPath = configPath[0]
	}
	if this.configPath == "" {
		this.configPath = "config/config.json"
	}

	btes, err := ioutil.ReadFile(this.configPath)
	this.Config = DeepCopyMap(Config).(Map)
	if err == nil {
		cmap := Map{}
		// A corrupt file is reported through this.Error.
		cmap.JsonToMap(string(btes), &this.Error)
		for k, v := range cmap {
			this.Config[k] = v // application config
			Config[k] = v      // package-wide (system) config
		}
	} else {
		logFmt("配置文件不存在,或者配置出错,使用缺省默认配置", 2)
	}

	// Never rewrite a file that failed to parse; that would lose its data.
	if this.Error.GetError() != nil {
		return
	}

	var configByte bytes.Buffer
	if err = json.Indent(&configByte, []byte(this.Config.ToJsonString()), "", "\t"); err != nil {
		// Writing now would truncate the file to an empty/partial document.
		this.Error.SetError(err)
		return
	}
	// Leave the file untouched when its content already matches.
	if len(btes) != 0 && configByte.String() == string(btes) {
		return
	}

	// Real permission bits are required: os.ModeDir / os.ModeAppend carry
	// perm 0000 and would create an inaccessible directory and file.
	dir := filepath.Dir(this.configPath)
	if err = os.MkdirAll(dir, 0755); err != nil {
		this.Error.SetError(err)
		return
	}
	if err = ioutil.WriteFile(this.configPath, configByte.Bytes(), 0644); err != nil {
		this.Error.SetError(err)
	}

	// The option descriptions file is best effort; failures are ignored.
	var configNoteByte bytes.Buffer
	if json.Indent(&configNoteByte, []byte(ConfigNote.ToJsonString()), "", "\t") == nil {
		_ = ioutil.WriteFile(filepath.Join(dir, "confignote.json"), configNoteByte.Bytes(), 0644)
	}
}

// SetConnectListener appends a request listener. Listeners run in order for
// every request; returning true continues to the controller layer, false
// stops processing (the context's View is still rendered).
func (this *Application) SetConnectListener(lis func(this *Context) bool) {
	this.connectListener = append(this.connectListener, lis)
}

// urlSer splits a request URI into its path (everything before the first
// '?') and the non-empty '/'-separated segments of that path.
func (this *Application) urlSer(uri string) (string, []string) {
	path := uri
	// strings.Index yields a byte offset, so slice by bytes directly.
	if q := strings.Index(uri, "?"); q != -1 {
		path = uri[:q]
	}
	segments := strings.FieldsFunc(path, func(r rune) bool { return r == '/' })
	return path, segments
}

// handler serves a single request: it resolves the session, builds the
// Context, applies CORS headers, runs the connect listeners, dispatches to a
// registered controller method and otherwise serves a static file from the
// "tpt" directory.
func (this *Application) handler(w http.ResponseWriter, req *http.Request) {
	// Session resolution:
	//   - cookie present and no differing 32-char "token" parameter: reuse it;
	//   - otherwise use the 32-char token if supplied, else a freshly
	//     generated id, and (re)issue the session cookie.
	sessionName := this.Config.GetString("sessionName")
	cookie, err := req.Cookie(sessionName)
	// NOTE(review): the id is Md5 of Rand(10); if Rand is not a CSPRNG the
	// session id is guessable — consider crypto/rand.
	sessionID := Md5(strconv.Itoa(Rand(10)))
	token := req.FormValue("token")
	if err != nil || (len(token) == 32 && cookie.Value != token) {
		if len(token) == 32 {
			sessionID = token
		}
		http.SetCookie(w, &http.Cookie{Name: sessionName, Value: sessionID, Path: "/"})
	} else {
		sessionID = cookie.Value
	}

	unescapeURL, err := url.QueryUnescape(req.RequestURI)
	if err != nil {
		unescapeURL = req.RequestURI
	}
	// Route on the unescaped path without its query string.
	handlerStr, routerString := this.urlSer(unescapeURL)

	context := Context{
		SessionIns:   SessionIns{SessionId: sessionID, LongCache: this.sessionLong, ShortCache: this.sessionShort},
		CacheIns:     this.CacheIns,
		Resp:         w,
		Req:          req,
		Application:  this,
		RouterString: routerString,
		Config:       this.Config,
		Db:           &this.Db,
		HandlerStr:   handlerStr,
	}

	header := w.Header()
	header.Set("Content-Type", "text/html; charset=utf-8")

	this.crossDomain(&context)

	if this.Config.GetInt("connectLogShow") != 0 {
		// SplitHostPort handles IPv6 addresses such as "[::1]:8080".
		host, _, splitErr := net.SplitHostPort(req.RemoteAddr)
		if splitErr != nil {
			host = req.RemoteAddr
		}
		logFmt(host+" "+context.HandlerStr, 0, LOG_INFO)
	}

	for _, listener := range this.connectListener {
		if !listener(&context) {
			context.View()
			return
		}
	}

	// Listeners may have rewritten HandlerStr, so read it only now.
	modeRouterStrict, ok := this.Config.Get("modeRouterStrict").(bool)
	if !ok {
		modeRouterStrict = true
	}
	tempHandlerStr := context.HandlerStr
	if !modeRouterStrict {
		tempHandlerStr = strings.ToLower(tempHandlerStr)
	}

	if method := this.MethodRouter[tempHandlerStr]; method != nil {
		method(&context)
		context.View()
		return
	}

	// Static file fallback.
	path := this.Config.GetString("tpt") + tempHandlerStr
	if path == "" {
		w.WriteHeader(404)
		return
	}
	// A directory path resolves to the first existing "defFile" entry.
	if strings.HasSuffix(path, "/") {
		defFile := this.Config.GetSlice("defFile")
		for i := 0; i < len(defFile); i++ {
			candidate := path + defFile.GetString(i)
			if _, statErr := os.Stat(candidate); statErr == nil {
				path = candidate
				break
			}
		}
		if strings.HasSuffix(path, "/") {
			w.WriteHeader(404)
			return
		}
	}
	// Refuse hidden files and "." / ".." path components.
	if strings.Contains(path, "/.") {
		w.WriteHeader(404)
		return
	}

	// Let http.ServeFile determine the content type.
	header.Del("Content-Type")
	if this.Config.GetInt("debug") != 1 {
		header.Set("Cache-Control", "public")
	}
	if strings.Contains(path, ".m3u8") {
		header.Add("Content-Type", "audio/mpegurl")
	}
	http.ServeFile(w, req, path)
}

// crossDomain writes CORS response headers according to the "crossDomain"
// setting: empty disables CORS, a concrete value is used as the allowed
// origin, and "*" echoes the request's Origin (or the scheme://host of its
// Referer when Origin is absent).
func (this *Application) crossDomain(context *Context) {
	crossDomain := context.Config.GetString("crossDomain")
	if crossDomain == "" {
		return
	}

	header := context.Resp.Header()
	header.Set("Access-Control-Allow-Origin", "*")
	header.Set("Access-Control-Allow-Methods", "*")
	header.Set("Access-Control-Allow-Credentials", "true")
	header.Set("Access-Control-Expose-Headers", "*")
	header.Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Access-Token")

	if crossDomain != "*" {
		header.Set("Access-Control-Allow-Origin", crossDomain)
		return
	}

	// The allowed origin is now derived per request, so shared caches must
	// key responses on the Origin header.
	header.Add("Vary", "Origin")

	if origin := context.Req.Header.Get("Origin"); origin != "" {
		header.Set("Access-Control-Allow-Origin", origin)
		return
	}

	if refer := context.Req.Header.Get("Referer"); refer != "" {
		// An origin is exactly scheme://host[:port]; drop path, query and fragment.
		if u, err := url.Parse(refer); err == nil && u.Scheme != "" && u.Host != "" {
			header.Set("Access-Control-Allow-Origin", u.Scheme+"://"+u.Host)
		}
	}
}